diff --git a/Cargo.lock b/Cargo.lock index 01ffa32..71fd459 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -180,6 +180,28 @@ dependencies = [ "syn", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -439,7 +461,7 @@ dependencies = [ "humantime", "ignore", "im-rc", - "indexmap", + "indexmap 2.8.0", "itertools", "jobserver", "lazycell", @@ -516,7 +538,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b489cbdae63be32c040b5fe81b0f7725e563bcd805bb828e746971a4967aaf28" dependencies = [ "cargo-credential", - "security-framework", + "security-framework 3.2.0", ] [[package]] @@ -545,7 +567,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "527f6e2a4e80492e90628052be879a5996c2453ad5ec745bfa310a80b7eca20a" dependencies = [ "anyhow", - "core-foundation", + "core-foundation 0.10.0", "filetime", "hex", "ignore", @@ -713,6 +735,16 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation" version = "0.10.0" @@ -2274,6 +2306,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.14.5" @@ -2452,6 +2490,22 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.10" @@ -2670,6 +2724,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", + "serde", +] + [[package]] name = "indexmap" version = "2.8.0" @@ -3059,6 +3124,23 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndarray" version = "0.16.1" @@ -3152,6 
+3234,23 @@ dependencies = [ "memchr", ] +[[package]] +name = "ollama-rs" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5df54edb7e1264719be607cd40590d3769b5b35a2623e6e02681e6591aea5b8" +dependencies = [ + "async-stream", + "log", + "reqwest", + "schemars", + "serde", + "serde_json", + "static_assertions", + "thiserror 2.0.12", + "url", +] + [[package]] name = "once_cell" version = "1.21.2" @@ -3716,12 +3815,14 @@ dependencies = [ "http-body-util", "hyper", "hyper-rustls", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", "mime_guess", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -3735,6 +3836,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-native-tls", "tokio-rustls", "tokio-util", "tower", @@ -3863,7 +3965,7 @@ checksum = "781442f29170c5c93b7185ad559492601acdc71d5bb0706f5868094f45cfcd08" [[package]] name = "rustdocs_mcp_server" -version = "1.1.0" +version = "1.3.1" dependencies = [ "anyhow", "async-openai", @@ -3874,6 +3976,7 @@ dependencies = [ "dotenvy", "futures", "ndarray", + "ollama-rs", "rmcp", "schemars", "scraper", @@ -3883,6 +3986,7 @@ dependencies = [ "thiserror 2.0.12", "tiktoken-rs", "tokio", + "url", "walkdir", "xdg", ] @@ -3948,7 +4052,7 @@ dependencies = [ "openssl-probe", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.2.0", ] [[package]] @@ -4017,6 +4121,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" dependencies = [ "dyn-clone", + "indexmap 1.9.3", "schemars_derive", "serde", "serde_json", @@ -4079,6 +4184,19 @@ dependencies = [ "zeroize", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + [[package]] name = "security-framework" version = "3.2.0" @@ -4086,7 +4204,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ "bitflags", - "core-foundation", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -4644,6 +4762,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.2" @@ -4705,7 +4833,7 @@ version = "0.22.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" dependencies = [ - "indexmap", + "indexmap 2.8.0", "serde", "serde_spanned", "toml_datetime", diff --git a/Cargo.toml b/Cargo.toml index 3d3f4d1..7abf9a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "1.3.1" edition = "2024" [dependencies] -rmcp = { version = "0.1.5", features = ["tower", "transport-io", "transport-sse-server", "macros", "server"] } # Add macros, server, schemars +rmcp = { version = "0.1.5", features = ["tower", "transport-io", "transport-sse-server", "macros", "server"] } tokio = { version = "1", features = ["macros", "rt-multi-thread"] 
} dotenvy = "0.15" serde = { version = "1", features = ["derive"] } @@ -12,20 +12,19 @@ serde_json = "1" thiserror = "2.0.12" walkdir = "2.5.0" scraper = "0.23.1" -ndarray = { version = "0.16.1", features = ["serde"] } # Enable serde feature -async-openai = "0.28.0" -# async-trait = "0.1.88" # Removed, likely no longer needed +ndarray = { version = "0.16.1", features = ["serde"] } +async-openai = "0.28.0" # Keep for chat completion (optional) +ollama-rs = "0.2.0" # Add Ollama client +url = "2.4" # Add url parsing for Ollama client construction futures = "0.3" -bincode = { version = "2.0.1", features = ["serde"] } # Enable serde integration +bincode = { version = "2.0.1", features = ["serde"] } tiktoken-rs = "0.6.0" -# Configure cargo crate to vendor openssl to avoid system mismatches cargo = { version = "0.87.1", default-features = false, features = ["vendored-openssl"] } tempfile = "3.19.1" anyhow = "1.0.97" schemars = "0.8.22" clap = { version = "4.5.34", features = ["cargo", "derive", "env"] } - # --- Platform Specific Dependencies --- [target.'cfg(not(target_os = "windows"))'.dependencies] @@ -34,12 +33,10 @@ xdg = { version = "2.5.2", features = ["serde"] } [target.'cfg(target_os = "windows")'.dependencies] dirs = "6.0.0" - # Optimize release builds for size [profile.release] -opt-level = "z" # Optimize for size -lto = true # Enable Link Time Optimization -codegen-units = 1 # Maximize size reduction opportunities -panic = "abort" # Abort on panic to remove unwinding code -strip = true # Strip symbols from binary - +opt-level = "z" +lto = true +codegen-units = 1 +panic = "abort" +strip = true \ No newline at end of file diff --git a/src/embeddings.rs b/src/embeddings.rs index 4080a61..ee76cb8 100644 --- a/src/embeddings.rs +++ b/src/embeddings.rs @@ -4,14 +4,18 @@ use async_openai::{ Client as OpenAIClient, }; use ndarray::{Array1, ArrayView1}; +use ollama_rs::{ + generation::embeddings::request::GenerateEmbeddingsRequest, + Ollama, +}; use std::sync::OnceLock; use std::sync::Arc; use tiktoken_rs::cl100k_base; use futures::stream::{self, StreamExt}; -// Static OnceLock for the OpenAI client +// Static OnceLocks for both clients pub static OPENAI_CLIENT: OnceLock<OpenAIClient<OpenAIConfig>> = OnceLock::new(); - +pub static OLLAMA_CLIENT: OnceLock<Ollama> = OnceLock::new(); use bincode::{Encode, Decode}; use serde::{Serialize, Deserialize}; @@ -20,11 +24,10 @@ use serde::{Serialize, Deserialize}; #[derive(Serialize, Deserialize, Debug, Encode, Decode)] pub struct CachedDocumentEmbedding { pub path: String, - pub content: String, // Add the extracted document content + pub content: String, pub vector: Vec<f32>, } - /// Calculates the cosine similarity between two vectors. pub fn cosine_similarity(v1: ArrayView1<f32>, v2: ArrayView1<f32>) -> f32 { let dot_product = v1.dot(&v2); @@ -38,60 +41,160 @@ pub fn cosine_similarity(v1: ArrayView1<f32>, v2: ArrayView1<f32>) -> f32 { } } -/// Generates embeddings for a list of documents using the OpenAI API. 
-pub async fn generate_embeddings( - client: &OpenAIClient<OpenAIConfig>, +/// Generates embeddings using Ollama with the nomic-embed-text model +pub async fn generate_ollama_embeddings( + ollama_client: &Ollama, documents: &[Document], model: &str, -) -> Result<(Vec<(String, Array1<f32>)>, usize), ServerError> { // Return tuple: (embeddings, total_tokens) - // eprintln!("Generating embeddings for {} documents...", documents.len()); +) -> Result<Vec<(String, Array1<f32>)>, ServerError> { + eprintln!("Generating embeddings for {} documents using Ollama...", documents.len()); + + const CONCURRENCY_LIMIT: usize = 4; // Lower concurrency for Ollama + const TOKEN_LIMIT: usize = 8000; // Adjust based on your model's limits - // Get the tokenizer for the model and wrap in Arc + // Get the tokenizer (we'll use this for approximate token counting) let bpe = Arc::new(cl100k_base().map_err(|e| ServerError::Tiktoken(e.to_string()))?); - const CONCURRENCY_LIMIT: usize = 8; // Number of concurrent requests - const TOKEN_LIMIT: usize = 8000; // Keep a buffer below the 8192 limit + let results = stream::iter(documents.iter().enumerate()) + .map(|(index, doc)| { + let ollama_client = ollama_client.clone(); + let model = model.to_string(); + let doc = doc.clone(); + let bpe = Arc::clone(&bpe); + + async move { + // Approximate token count for filtering + let token_count = bpe.encode_with_special_tokens(&doc.content).len(); + + if token_count > TOKEN_LIMIT { + eprintln!( + " Skipping document {}: Approximate tokens ({}) exceed limit ({}). Path: {}", + index + 1, + token_count, + TOKEN_LIMIT, + doc.path + ); + return Ok::<Option<(String, Array1<f32>)>, ServerError>(None); + } + + eprintln!( + " Processing document {} (approx {} tokens)... Path: {}", + index + 1, + token_count, + doc.path + ); + + // Create embeddings request for Ollama + let request = GenerateEmbeddingsRequest::new( + model, + doc.content.clone().into(), + ); + + match ollama_client.generate_embeddings(request).await { + Ok(response) => { + if let Some(embedding) = response.embeddings.first() { + let embedding_array = Array1::from(embedding.clone()); + eprintln!(" Received response for document {}.", index + 1); + Ok(Some((doc.path.clone(), embedding_array))) + } else { + Err(ServerError::Config(format!( + "No embeddings returned for document {}", + index + 1 + ))) + } + } + Err(e) => Err(ServerError::Config(format!( + "Ollama embedding error for document {}: {}", + index + 1, e + ))) + } + } + }) + .buffer_unordered(CONCURRENCY_LIMIT) + .collect::<Vec<Result<Option<(String, Array1<f32>)>, ServerError>>>() + .await; + + // Process collected results + let mut embeddings_vec = Vec::new(); + for result in results { + match result { + Ok(Some((path, embedding))) => { + embeddings_vec.push((path, embedding)); + } + Ok(None) => {} // Skipped document + Err(e) => { + eprintln!("Error during Ollama embedding generation: {}", e); + return Err(e); + } + } + } + + eprintln!( + "Finished generating Ollama embeddings. 
Successfully processed {} documents.", + embeddings_vec.len() + ); + Ok(embeddings_vec) +} + +/// Generates embeddings for a single text using Ollama (for questions) +pub async fn generate_single_ollama_embedding( + ollama_client: &Ollama, + text: &str, + model: &str, +) -> Result<Array1<f32>, ServerError> { + let request = GenerateEmbeddingsRequest::new( + model.to_string(), + text.to_string().into(), + ); + + match ollama_client.generate_embeddings(request).await { + Ok(response) => { + if let Some(embedding) = response.embeddings.first() { + Ok(Array1::from(embedding.clone())) + } else { + Err(ServerError::Config("No embedding returned".to_string())) + } + } + Err(e) => Err(ServerError::Config(format!( + "Ollama embedding error: {}", + e + ))) + } +} + +/// Legacy OpenAI embedding generation (kept for fallback) +pub async fn generate_openai_embeddings( + client: &OpenAIClient<OpenAIConfig>, + documents: &[Document], + model: &str, +) -> Result<(Vec<(String, Array1<f32>)>, usize), ServerError> { + // Keep the original OpenAI implementation for fallback + let bpe = Arc::new(cl100k_base().map_err(|e| ServerError::Tiktoken(e.to_string()))?); + + const CONCURRENCY_LIMIT: usize = 8; + const TOKEN_LIMIT: usize = 8000; let results = stream::iter(documents.iter().enumerate()) .map(|(index, doc)| { - // Clone client, model, doc, and Arc for the async block let client = client.clone(); let model = model.to_string(); let doc = doc.clone(); - let bpe = Arc::clone(&bpe); // Clone the Arc pointer + let bpe = Arc::clone(&bpe); async move { - // Calculate token count for this document let token_count = bpe.encode_with_special_tokens(&doc.content).len(); if token_count > TOKEN_LIMIT { - // eprintln!( - // " Skipping document {}: Actual tokens ({}) exceed limit ({}). Path: {}", - // index + 1, - // token_count, - // TOKEN_LIMIT, - // doc.path - // ); - // Return Ok(None) to indicate skipping, with 0 tokens processed for this doc - return Ok::<Option<(String, Array1<f32>, usize)>, ServerError>(None); // Include token count type + return Ok::<Option<(String, Array1<f32>, usize)>, ServerError>(None); } - // Prepare input for this single document let inputs: Vec<String> = vec![doc.content.clone()]; - let request = CreateEmbeddingRequestArgs::default() - .model(&model) // Use cloned model string + .model(&model) .input(inputs) - .build()?; // Propagates OpenAIError + .build()?; - // eprintln!( - // " Sending request for document {} ({} tokens)... 
Path: {}", - // index + 1, - // token_count, // Use correct variable name - // doc.path - // ); - let response = client.embeddings().create(request).await?; // Propagates OpenAIError - // eprintln!(" Received response for document {}.", index + 1); + let response = client.embeddings().create(request).await?; if response.data.len() != 1 { return Err(ServerError::OpenAI( @@ -107,39 +210,55 @@ pub async fn generate_embeddings( )); } - // Process result - let embedding_data = response.data.first().unwrap(); // Safe unwrap due to check above + let embedding_data = response.data.first().unwrap(); let embedding_array = Array1::from(embedding_data.embedding.clone()); - // Return Ok(Some(...)) for successful embedding, include token count - Ok(Some((doc.path.clone(), embedding_array, token_count))) // Include token count + Ok(Some((doc.path.clone(), embedding_array, token_count))) } }) - .buffer_unordered(CONCURRENCY_LIMIT) // Run up to CONCURRENCY_LIMIT futures concurrently - .collect::<Vec<Result<Option<(String, Array1<f32>, usize)>, ServerError>>>() // Update collected result type + .buffer_unordered(CONCURRENCY_LIMIT) + .collect::<Vec<Result<Option<(String, Array1<f32>, usize)>, ServerError>>>() .await; - // Process collected results, filtering out errors and skipped documents, summing tokens let mut embeddings_vec = Vec::new(); let mut total_processed_tokens: usize = 0; for result in results { match result { Ok(Some((path, embedding, tokens))) => { - embeddings_vec.push((path, embedding)); // Keep successful embeddings - total_processed_tokens += tokens; // Add tokens for successful ones + embeddings_vec.push((path, embedding)); + total_processed_tokens += tokens; } - Ok(None) => {} // Ignore skipped documents + Ok(None) => {} Err(e) => { - // Log error but potentially continue? Or return the first error? - // For now, let's return the first error encountered. - eprintln!("Error during concurrent embedding generation: {}", e); + eprintln!("Error during OpenAI embedding generation: {}", e); return Err(e); } } } eprintln!( - "Finished generating embeddings. Successfully processed {} documents ({} tokens).", +
Successfully processed {} documents ({} tokens).", embeddings_vec.len(), total_processed_tokens ); - Ok((embeddings_vec, total_processed_tokens)) // Return tuple + Ok((embeddings_vec, total_processed_tokens)) +} + +/// Main embedding generation function that tries Ollama first, falls back to OpenAI +pub async fn generate_embeddings( + documents: &[Document], + model: &str, +) -> Result<(Vec<(String, Array1)>, usize), ServerError> { + // Check if Ollama is available + if let Some(ollama_client) = OLLAMA_CLIENT.get() { + eprintln!("Using Ollama for embedding generation with model: {}", model); + // For Ollama, we don't track tokens the same way, so return 0 for token count + let embeddings = generate_ollama_embeddings(ollama_client, documents, model).await?; + Ok((embeddings, 0)) + } else if let Some(openai_client) = OPENAI_CLIENT.get() { + eprintln!("Fallback to OpenAI for embedding generation with model: {}", model); + generate_openai_embeddings(openai_client, documents, model).await + } else { + Err(ServerError::Config( + "No embedding client available (neither Ollama nor OpenAI)".to_string() + )) + } } \ No newline at end of file diff --git a/src/error.rs b/src/error.rs index f804a78..f1af916 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,25 +1,25 @@ -use rmcp::ServiceError; // Assuming ServiceError is the correct top-level error +use rmcp::ServiceError; use thiserror::Error; -use crate::doc_loader::DocLoaderError; // Need to import DocLoaderError from the sibling module +use crate::doc_loader::DocLoaderError; #[derive(Debug, Error)] pub enum ServerError { #[error("Environment variable not set: {0}")] MissingEnvVar(String), - // MissingArgument removed as clap handles this now #[error("Configuration Error: {0}")] Config(String), - #[error("MCP Service Error: {0}")] - Mcp(#[from] ServiceError), // Use ServiceError + Mcp(#[from] ServiceError), #[error("IO Error: {0}")] Io(#[from] std::io::Error), #[error("Document Loading Error: {0}")] DocLoader(#[from] DocLoaderError), #[error("OpenAI Error: {0}")] OpenAI(#[from] async_openai::error::OpenAIError), + #[error("Ollama Error: {0}")] + Ollama(#[from] ollama_rs::error::OllamaError), // Add Ollama error handling #[error("JSON Error: {0}")] - Json(#[from] serde_json::Error), // Add error for JSON deserialization + Json(#[from] serde_json::Error), #[error("Tiktoken Error: {0}")] Tiktoken(String), #[error("XDG Directory Error: {0}")] diff --git a/src/main.rs b/src/main.rs index cfd2cf1..c7dfec0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,33 +2,34 @@ mod doc_loader; mod embeddings; mod error; -mod server; // Keep server module as RustDocsServer is defined there +mod server; // Use necessary items from modules and crates use crate::{ doc_loader::Document, - embeddings::{generate_embeddings, CachedDocumentEmbedding, OPENAI_CLIENT}, + embeddings::{generate_embeddings, CachedDocumentEmbedding, OPENAI_CLIENT, OLLAMA_CLIENT}, error::ServerError, - server::RustDocsServer, // Import the updated RustDocsServer + server::RustDocsServer, }; use async_openai::{Client as OpenAIClient, config::OpenAIConfig}; use bincode::config; use cargo::core::PackageIdSpec; -use clap::Parser; // Import clap Parser +use clap::Parser; use ndarray::Array1; -// Import rmcp items needed for the new approach +use ollama_rs::Ollama; use rmcp::{ - transport::io::stdio, // Use the standard stdio transport - ServiceExt, // Import the ServiceExt trait for .serve() and .waiting() + transport::io::stdio, + ServiceExt, }; use std::{ collections::hash_map::DefaultHasher, env, 
fs::{self, File}, - hash::{Hash, Hasher}, // Import hashing utilities + hash::{Hash, Hasher}, io::BufReader, path::PathBuf, }; +// Removed unused url import #[cfg(not(target_os = "windows"))] use xdg::BaseDirectories; @@ -38,12 +39,24 @@ use xdg::BaseDirectories; #[command(author, version, about, long_about = None)] struct Cli { /// The package ID specification (e.g., "serde@^1.0", "tokio"). - #[arg()] // Positional argument + #[arg()] package_spec: String, /// Optional features to enable for the crate when generating documentation. - #[arg(short = 'F', long, value_delimiter = ',', num_args = 0..)] // Allow multiple comma-separated values + #[arg(short = 'F', long, value_delimiter = ',', num_args = 0..)] features: Option>, + + /// Use OpenAI instead of Ollama for embeddings (fallback mode) + #[arg(long)] + use_openai: bool, + + /// Specify Ollama host (default: localhost) + #[arg(long, default_value = "localhost")] + ollama_host: String, + + /// Specify Ollama port (default: 11434) + #[arg(long, default_value_t = 11434)] + ollama_port: u16, } // Helper function to create a stable hash from features @@ -52,12 +65,12 @@ fn hash_features(features: &Option>) -> String { .as_ref() .map(|f| { let mut sorted_features = f.clone(); - sorted_features.sort_unstable(); // Sort for consistent hashing + sorted_features.sort_unstable(); let mut hasher = DefaultHasher::new(); sorted_features.hash(&mut hasher); - format!("{:x}", hasher.finish()) // Return hex representation of hash + format!("{:x}", hasher.finish()) }) - .unwrap_or_else(|| "no_features".to_string()) // Use a specific string if no features + .unwrap_or_else(|| "no_features".to_string()) } #[tokio::main] @@ -67,9 +80,9 @@ async fn main() -> Result<(), ServerError> { // --- Parse CLI Arguments --- let cli = Cli::parse(); - let specid_str = cli.package_spec.trim().to_string(); // Trim whitespace + let specid_str = cli.package_spec.trim().to_string(); let features = cli.features.map(|f| { - f.into_iter().map(|s| s.trim().to_string()).collect() // Trim each feature + f.into_iter().map(|s| s.trim().to_string()).collect() }); // Parse the specid string @@ -91,19 +104,81 @@ async fn main() -> Result<(), ServerError> { specid_str, crate_name, crate_version_req, features ); - // --- Determine Paths (incorporating features) --- + // --- Initialize Clients --- + + // Initialize Ollama client (unless forced to use OpenAI) + if !cli.use_openai { + // Use the simpler approach: default() for localhost:11434, or construct URL for custom hosts + let ollama_client = if cli.ollama_host == "localhost" && cli.ollama_port == 11434 { + // Use the default for the most common case + eprintln!("Initializing Ollama client with default settings (localhost:11434)"); + Ollama::default() + } else { + // For custom hosts, construct the URL properly + let scheme_and_host = if cli.ollama_host.contains("://") { + cli.ollama_host.clone() + } else { + format!("http://{}", cli.ollama_host) + }; + eprintln!("Initializing Ollama client at {}:{}", scheme_and_host, cli.ollama_port); + Ollama::new(scheme_and_host, cli.ollama_port) + }; + + // Test Ollama connection + match ollama_client.show_model_info("nomic-embed-text".to_string()).await { + Ok(_) => { + eprintln!("✓ Connected to Ollama, nomic-embed-text model available"); + OLLAMA_CLIENT.set(ollama_client) + .map_err(|_| ServerError::Config("Failed to set Ollama client".to_string()))?; + } + Err(e) => { + eprintln!("⚠ Failed to connect to Ollama or nomic-embed-text not available: {}", e); + eprintln!("Make sure Ollama is running 
and pull the model with: ollama pull nomic-embed-text"); + return Err(ServerError::Config(format!( + "Ollama connection failed: {}. Try using --use-openai flag as fallback.", + e + ))); + } + } + } + + // Initialize OpenAI client (for chat completion and fallback) + let openai_client = if let Ok(api_base) = env::var("OPENAI_API_BASE") { + let config = OpenAIConfig::new().with_api_base(api_base); + OpenAIClient::with_config(config) + } else { + OpenAIClient::new() + }; + + // Always set OpenAI client for chat completion + OPENAI_CLIENT.set(openai_client.clone()) + .map_err(|_| ServerError::Config("Failed to set OpenAI client".to_string()))?; + + // Check if we have any embedding client + if OLLAMA_CLIENT.get().is_none() && !cli.use_openai { + eprintln!("No Ollama client available and not forced to use OpenAI"); + let _openai_api_key = env::var("OPENAI_API_KEY") + .map_err(|_| ServerError::MissingEnvVar("OPENAI_API_KEY".to_string()))?; + eprintln!("Falling back to OpenAI for embeddings"); + } - // Sanitize the version requirement string + // --- Determine Paths (incorporating features and model type) --- let sanitized_version_req = crate_version_req .replace(|c: char| !c.is_alphanumeric() && c != '.' && c != '-', "_"); - // Generate a stable hash for the features to use in the path let features_hash = hash_features(&features); + + // Include model type in cache path to avoid conflicts + let model_type = if OLLAMA_CLIENT.get().is_some() && !cli.use_openai { + "ollama" + } else { + "openai" + }; - // Construct the relative path component including features hash let embeddings_relative_path = PathBuf::from(&crate_name) .join(&sanitized_version_req) - .join(&features_hash) // Add features hash as a directory level + .join(&features_hash) + .join(model_type) // Separate cache for different embedding models .join("embeddings.bin"); #[cfg(not(target_os = "windows"))] @@ -121,7 +196,6 @@ async fn main() -> Result<(), ServerError> { ServerError::Config("Could not determine cache directory on Windows".to_string()) })?; let app_cache_dir = cache_dir.join("rustdocs-mcp-server"); - // Ensure the base app cache directory exists fs::create_dir_all(&app_cache_dir).map_err(ServerError::Io)?; app_cache_dir.join(embeddings_relative_path) }; @@ -181,17 +255,6 @@ async fn main() -> Result<(), ServerError> { let mut generation_cost: Option = None; let mut documents_for_server: Vec = loaded_documents_from_cache.unwrap_or_default(); - // --- Initialize OpenAI Client (needed for question embedding even if cache hit) --- - let openai_client = if let Ok(api_base) = env::var("OPENAI_API_BASE") { - let config = OpenAIConfig::new().with_api_base(api_base); - OpenAIClient::with_config(config) - } else { - OpenAIClient::new() - }; - OPENAI_CLIENT - .set(openai_client.clone()) // Clone the client for the OnceCell - .expect("Failed to set OpenAI client"); - let final_embeddings = match loaded_embeddings { Some(embeddings) => { eprintln!("Using embeddings and documents loaded from cache."); @@ -200,33 +263,39 @@ async fn main() -> Result<(), ServerError> { None => { eprintln!("Proceeding with documentation loading and embedding generation."); - let _openai_api_key = env::var("OPENAI_API_KEY") - .map_err(|_| ServerError::MissingEnvVar("OPENAI_API_KEY".to_string()))?; - eprintln!( "Loading documents for crate: {} (Version Req: {}, Features: {:?})", crate_name, crate_version_req, features ); - // Pass features to load_documents let loaded_documents = - doc_loader::load_documents(&crate_name, &crate_version_req, 
features.as_ref())?; // Pass features here + doc_loader::load_documents(&crate_name, &crate_version_req, features.as_ref())?; eprintln!("Loaded {} documents.", loaded_documents.len()); documents_for_server = loaded_documents.clone(); eprintln!("Generating embeddings..."); - let embedding_model: String = env::var("EMBEDDING_MODEL") - .unwrap_or_else(|_| "text-embedding-3-small".to_string()); - let (generated_embeddings, total_tokens) = - generate_embeddings(&openai_client, &loaded_documents, &embedding_model).await?; + let embedding_model = if OLLAMA_CLIENT.get().is_some() && !cli.use_openai { + "nomic-embed-text".to_string() + } else { + env::var("EMBEDDING_MODEL") + .unwrap_or_else(|_| "text-embedding-3-small".to_string()) + }; - let cost_per_million = 0.02; - let estimated_cost = (total_tokens as f64 / 1_000_000.0) * cost_per_million; - eprintln!( - "Embedding generation cost for {} tokens: ${:.6}", - total_tokens, estimated_cost - ); - generated_tokens = Some(total_tokens); - generation_cost = Some(estimated_cost); + let (generated_embeddings, total_tokens) = + generate_embeddings(&loaded_documents, &embedding_model).await?; + + // Only calculate cost for OpenAI + if cli.use_openai || OLLAMA_CLIENT.get().is_none() { + let cost_per_million = 0.02; + let estimated_cost = (total_tokens as f64 / 1_000_000.0) * cost_per_million; + eprintln!( + "Embedding generation cost for {} tokens: ${:.6}", + total_tokens, estimated_cost + ); + generated_tokens = Some(total_tokens); + generation_cost = Some(estimated_cost); + } else { + eprintln!("Generated embeddings using Ollama (local, no cost)"); + } eprintln!( "Saving generated documents and embeddings to: {:?}", @@ -293,50 +362,62 @@ async fn main() -> Result<(), ServerError> { .map(|f| format!(" Features: {:?}", f)) .unwrap_or_default(); + let model_info = if OLLAMA_CLIENT.get().is_some() && !cli.use_openai { + "using Ollama/nomic-embed-text".to_string() + } else { + "using OpenAI".to_string() + }; + let startup_message = if loaded_from_cache { format!( - "Server for crate '{}' (Version Req: '{}'{}) initialized. Loaded {} embeddings from cache.", - crate_name, crate_version_req, features_str, final_embeddings.len() + "Server for crate '{}' (Version Req: '{}'{}) initialized. Loaded {} embeddings from cache ({}).", + crate_name, crate_version_req, features_str, final_embeddings.len(), model_info ) } else { let tokens = generated_tokens.unwrap_or(0); let cost = generation_cost.unwrap_or(0.0); - format!( - "Server for crate '{}' (Version Req: '{}'{}) initialized. Generated {} embeddings for {} tokens (Est. Cost: ${:.6}).", - crate_name, - crate_version_req, - features_str, - final_embeddings.len(), - tokens, - cost - ) + if OLLAMA_CLIENT.get().is_some() && !cli.use_openai { + format!( + "Server for crate '{}' (Version Req: '{}'{}) initialized. Generated {} embeddings using Ollama (local).", + crate_name, + crate_version_req, + features_str, + final_embeddings.len(), + ) + } else { + format!( + "Server for crate '{}' (Version Req: '{}'{}) initialized. Generated {} embeddings for {} tokens (Est. 
Cost: ${:.6}) using OpenAI.", + crate_name, + crate_version_req, + features_str, + final_embeddings.len(), + tokens, + cost + ) + } }; - // Create the service instance using the updated ::new() let service = RustDocsServer::new( - crate_name.clone(), // Pass crate_name directly + crate_name.clone(), documents_for_server, final_embeddings, startup_message, )?; - // --- Use standard stdio transport and ServiceExt --- eprintln!("Rust Docs MCP server starting via stdio..."); - // Serve the server using the ServiceExt trait and standard stdio transport let server_handle = service.serve(stdio()).await.map_err(|e| { eprintln!("Failed to start server: {:?}", e); - ServerError::McpRuntime(e.to_string()) // Use the new McpRuntime variant + ServerError::McpRuntime(e.to_string()) })?; - eprintln!("{} Docs MCP server running...", &crate_name); + eprintln!("{} Docs MCP server running {} ...", &crate_name, model_info); - // Wait for the server to complete (e.g., stdin closed) server_handle.waiting().await.map_err(|e| { eprintln!("Server encountered an error while running: {:?}", e); - ServerError::McpRuntime(e.to_string()) // Use the new McpRuntime variant + ServerError::McpRuntime(e.to_string()) })?; eprintln!("Rust Docs MCP server stopped."); Ok(()) -} +} \ No newline at end of file diff --git a/src/server.rs b/src/server.rs index 9e886ca..e1f2b19 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,31 +1,30 @@ use crate::{ doc_loader::Document, - embeddings::{OPENAI_CLIENT, cosine_similarity}, - error::ServerError, // Keep ServerError for ::new() + embeddings::{OPENAI_CLIENT, OLLAMA_CLIENT, cosine_similarity, generate_single_ollama_embedding}, + error::ServerError, }; use async_openai::{ types::{ ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, }, - // Client as OpenAIClient, // Removed unused import }; use ndarray::Array1; -use rmcp::model::AnnotateAble; // Import trait for .no_annotation() +use rmcp::model::AnnotateAble; use rmcp::{ Error as McpError, Peer, - ServerHandler, // Import necessary rmcp items + ServerHandler, model::{ CallToolResult, Content, GetPromptRequestParam, GetPromptResult, - /* EmptyObject, ErrorCode, */ Implementation, - ListPromptsResult, // Removed EmptyObject, ErrorCode + Implementation, + ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult, - LoggingLevel, // Uncommented ListToolsResult + LoggingLevel, LoggingMessageNotification, LoggingMessageNotificationMethod, LoggingMessageNotificationParam, @@ -33,7 +32,6 @@ use rmcp::{ PaginatedRequestParam, ProtocolVersion, RawResource, - /* Prompt, PromptArgument, PromptMessage, PromptMessageContent, PromptMessageRole, */ // Removed Prompt types ReadResourceRequestParam, ReadResourceResult, Resource, @@ -45,55 +43,53 @@ use rmcp::{ service::{RequestContext, RoleServer}, tool, }; -use schemars::JsonSchema; // Import JsonSchema -use serde::Deserialize; // Import Deserialize +use schemars::JsonSchema; +use serde::Deserialize; use serde_json::json; -use std::{/* borrow::Cow, */ env, sync::Arc}; // Removed borrow::Cow +use std::{env, sync::Arc}; use tokio::sync::Mutex; +// Add Ollama imports for chat completion - using the correct import paths +use ollama_rs::generation::chat::{ChatMessage, MessageRole}; +use ollama_rs::generation::chat::request::ChatMessageRequest; + // --- Argument Struct for the Tool --- #[derive(Debug, Deserialize, JsonSchema)] struct QueryRustDocsArgs { #[schemars(description = "The specific question about the 
crate's API or usage.")] question: String, - // Removed crate_name field as it's implicit to the server instance } // --- Main Server Struct --- -// No longer needs ServerState, holds data directly -#[derive(Clone)] // Add Clone for tool macro requirements +#[derive(Clone)] pub struct RustDocsServer { - crate_name: Arc<String>, // Use Arc for cheap cloning + crate_name: Arc<String>, documents: Arc<Vec<Document>>, embeddings: Arc<Vec<(String, Array1<f32>)>>, - peer: Arc<Mutex<Option<Peer<RoleServer>>>>, // Uses tokio::sync::Mutex - startup_message: Arc<Mutex<Option<String>>>, // Keep the message itself - startup_message_sent: Arc<Mutex<bool>>, // Flag to track if sent (using tokio::sync::Mutex) - // tool_name and info are handled by ServerHandler/macros now + peer: Arc<Mutex<Option<Peer<RoleServer>>>>, + startup_message: Arc<Mutex<Option<String>>>, + startup_message_sent: Arc<Mutex<bool>>, } impl RustDocsServer { - // Updated constructor pub fn new( crate_name: String, documents: Vec<Document>, embeddings: Vec<(String, Array1<f32>)>, startup_message: String, ) -> Result<Self, ServerError> { - // Keep ServerError for potential future init errors Ok(Self { crate_name: Arc::new(crate_name), documents: Arc::new(documents), embeddings: Arc::new(embeddings), - peer: Arc::new(Mutex::new(None)), // Uses tokio::sync::Mutex - startup_message: Arc::new(Mutex::new(Some(startup_message))), // Initialize message - startup_message_sent: Arc::new(Mutex::new(false)), // Initialize flag to false + peer: Arc::new(Mutex::new(None)), + startup_message: Arc::new(Mutex::new(Some(startup_message))), + startup_message_sent: Arc::new(Mutex::new(false)), }) } - // Helper function to send log messages via MCP notification (remains mostly the same) pub fn send_log(&self, level: LoggingLevel, message: String) { let peer_arc = Arc::clone(&self.peer); tokio::spawn(async move { @@ -119,25 +115,170 @@ impl RustDocsServer { }); } - // Helper for creating simple text resources (like in counter example) fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { RawResource::new(uri, name.to_string()).no_annotation() } + + /// Generate embedding for a question using the same model that was used for documents + async fn generate_question_embedding(&self, question: &str) -> Result<Array1<f32>, McpError> { + // First try Ollama (preferred for consistency) + if let Some(ollama_client) = OLLAMA_CLIENT.get() { + match generate_single_ollama_embedding(ollama_client, question, "nomic-embed-text").await { + Ok(embedding) => return Ok(embedding), + Err(e) => { + eprintln!("Failed to generate question embedding with Ollama: {}", e); + self.send_log( + LoggingLevel::Warning, + format!("Ollama embedding failed, trying OpenAI fallback: {}", e), + ); + } + } + } + + // Fallback to OpenAI if Ollama fails or is not available + if let Some(openai_client) = OPENAI_CLIENT.get() { + let embedding_model: String = env::var("EMBEDDING_MODEL") + .unwrap_or_else(|_| "text-embedding-3-small".to_string()); + + let question_embedding_request = CreateEmbeddingRequestArgs::default() + .model(embedding_model) + .input(question.to_string()) + .build() + .map_err(|e| { + McpError::internal_error(format!("Failed to build embedding request: {}", e), None) + })?; + + let question_embedding_response = openai_client + .embeddings() + .create(question_embedding_request) + .await + .map_err(|e| McpError::internal_error(format!("OpenAI API error: {}", e), None))?; + + let question_embedding = question_embedding_response.data.first().ok_or_else(|| { + McpError::internal_error("Failed to get embedding for question", None) + })?; + + return Ok(Array1::from(question_embedding.embedding.clone())); + } + + Err(McpError::internal_error( + "No embedding client available (neither Ollama nor 
OpenAI)", + None, + )) + } + + /// Generate chat completion using Ollama or OpenAI + async fn generate_chat_completion( + &self, + system_prompt: &str, + user_prompt: &str, + ) -> Result { + // First try Ollama (preferred for consistency) + if let Some(ollama_client) = OLLAMA_CLIENT.get() { + // Get the chat model from environment variable, default to llama3.2 + let chat_model = env::var("OLLAMA_CHAT_MODEL") + .unwrap_or_else(|_| "llama3.2".to_string()); + + self.send_log( + LoggingLevel::Info, + format!("Using Ollama for chat completion with model: {}", chat_model), + ); + + // Create the chat messages - system message followed by user message + let messages = vec![ + ChatMessage::system(system_prompt.to_string()), + ChatMessage::user(user_prompt.to_string()), + ]; + + // Create the chat request + let chat_request = ChatMessageRequest::new(chat_model, messages); + + match ollama_client.send_chat_messages(chat_request).await { + Ok(response) => { + // The response.message is a ChatMessage directly, not an Option + // We need to access its content field + return Ok(response.message.content); + } + Err(e) => { + eprintln!("Failed to generate chat completion with Ollama: {}", e); + self.send_log( + LoggingLevel::Warning, + format!("Ollama chat failed, trying OpenAI fallback: {}", e), + ); + } + } + } + + // Fallback to OpenAI if Ollama fails or is not available + if let Some(openai_client) = OPENAI_CLIENT.get() { + self.send_log( + LoggingLevel::Info, + "Using OpenAI for chat completion".to_string(), + ); + + let llm_model: String = env::var("LLM_MODEL") + .unwrap_or_else(|_| "gpt-4o-mini-2024-07-18".to_string()); + + let chat_request = CreateChatCompletionRequestArgs::default() + .model(llm_model) + .messages(vec![ + ChatCompletionRequestSystemMessageArgs::default() + .content(system_prompt) + .build() + .map_err(|e| { + McpError::internal_error( + format!("Failed to build system message: {}", e), + None, + ) + })? + .into(), + ChatCompletionRequestUserMessageArgs::default() + .content(user_prompt) + .build() + .map_err(|e| { + McpError::internal_error( + format!("Failed to build user message: {}", e), + None, + ) + })? + .into(), + ]) + .build() + .map_err(|e| { + McpError::internal_error( + format!("Failed to build chat request: {}", e), + None, + ) + })?; + + let chat_response = openai_client.chat().create(chat_request).await.map_err(|e| { + McpError::internal_error(format!("OpenAI chat API error: {}", e), None) + })?; + + return chat_response + .choices + .first() + .and_then(|choice| choice.message.content.clone()) + .ok_or_else(|| McpError::internal_error("No response from OpenAI", None)); + } + + Err(McpError::internal_error( + "No chat client available (neither Ollama nor OpenAI)", + None, + )) + } } // --- Tool Implementation --- -#[tool(tool_box)] // Add tool_box here as well, mirroring the example -// Tool methods go in a regular impl block +#[tool(tool_box)] impl RustDocsServer { - // Define the tool using the tool macro - // Name removed; will be handled dynamically by overriding list_tools/get_tool #[tool( description = "Query documentation for a specific Rust crate using semantic search and LLM summarization." 
)] async fn query_rust_docs( &self, - #[tool(aggr)] // Aggregate arguments into the struct + #[tool(aggr)] args: QueryRustDocsArgs, ) -> Result { // --- Send Startup Message (if not already sent) --- @@ -145,20 +286,15 @@ impl RustDocsServer { if !*sent_guard { let mut msg_guard = self.startup_message.lock().await; if let Some(message) = msg_guard.take() { - // Take the message out self.send_log(LoggingLevel::Info, message); - *sent_guard = true; // Mark as sent + *sent_guard = true; } - // Drop guards explicitly to avoid holding locks longer than needed drop(msg_guard); drop(sent_guard); } else { - // Drop guard if already sent drop(sent_guard); } - // Argument validation for crate_name removed - let question = &args.question; // Log received query via MCP @@ -170,32 +306,8 @@ impl RustDocsServer { ), ); - // --- Embedding Generation for Question --- - let client = OPENAI_CLIENT - .get() - .ok_or_else(|| McpError::internal_error("OpenAI client not initialized", None))?; - - let embedding_model: String = - env::var("EMBEDDING_MODEL").unwrap_or_else(|_| "text-embedding-3-small".to_string()); - let question_embedding_request = CreateEmbeddingRequestArgs::default() - .model(embedding_model) - .input(question.to_string()) - .build() - .map_err(|e| { - McpError::internal_error(format!("Failed to build embedding request: {}", e), None) - })?; - - let question_embedding_response = client - .embeddings() - .create(question_embedding_request) - .await - .map_err(|e| McpError::internal_error(format!("OpenAI API error: {}", e), None))?; - - let question_embedding = question_embedding_response.data.first().ok_or_else(|| { - McpError::internal_error("Failed to get embedding for question", None) - })?; - - let question_vector = Array1::from(question_embedding.embedding.clone()); + // --- Generate question embedding using the same model as documents --- + let question_vector = self.generate_question_embedding(question).await?; // --- Find Best Matching Document --- let mut best_match: Option<(&str, f32)> = None; @@ -208,8 +320,15 @@ impl RustDocsServer { // --- Generate Response using LLM --- let response_text = match best_match { - Some((best_path, _score)) => { - eprintln!("Best match found: {}", best_path); + Some((best_path, score)) => { + eprintln!("Best match found: {} (similarity: {:.3})", best_path, score); + + // Log the similarity score via MCP + self.send_log( + LoggingLevel::Info, + format!("Best matching document: {} (similarity: {:.3})", best_path, score), + ); + let context_doc = self.documents.iter().find(|doc| doc.path == best_path); if let Some(doc) = context_doc { @@ -225,49 +344,8 @@ impl RustDocsServer { doc.content, question ); - let llm_model: String = env::var("LLM_MODEL") - .unwrap_or_else(|_| "gpt-4o-mini-2024-07-18".to_string()); - let chat_request = CreateChatCompletionRequestArgs::default() - .model(llm_model) - .messages(vec![ - ChatCompletionRequestSystemMessageArgs::default() - .content(system_prompt) - .build() - .map_err(|e| { - McpError::internal_error( - format!("Failed to build system message: {}", e), - None, - ) - })? - .into(), - ChatCompletionRequestUserMessageArgs::default() - .content(user_prompt) - .build() - .map_err(|e| { - McpError::internal_error( - format!("Failed to build user message: {}", e), - None, - ) - })? 
- .into(), - ]) - .build() - .map_err(|e| { - McpError::internal_error( - format!("Failed to build chat request: {}", e), - None, - ) - })?; - - let chat_response = client.chat().create(chat_request).await.map_err(|e| { - McpError::internal_error(format!("OpenAI chat API error: {}", e), None) - })?; - - chat_response - .choices - .first() - .and_then(|choice| choice.message.content.clone()) - .unwrap_or_else(|| "Error: No response from LLM.".to_string()) + // Use the new chat completion method that supports both Ollama and OpenAI + self.generate_chat_completion(&system_prompt, &user_prompt).await? } else { "Error: Could not find content for best matching document.".to_string() } @@ -285,42 +363,45 @@ impl RustDocsServer { // --- ServerHandler Implementation --- -#[tool(tool_box)] // Use imported tool macro directly +#[tool(tool_box)] impl ServerHandler for RustDocsServer { fn get_info(&self) -> ServerInfo { - // Define capabilities using the builder let capabilities = ServerCapabilities::builder() - .enable_tools() // Enable tools capability - .enable_logging() // Enable logging capability - // Add other capabilities like resources, prompts if needed later + .enable_tools() + .enable_logging() .build(); + // Determine which embedding and chat models are being used + let model_info = if OLLAMA_CLIENT.get().is_some() { + let chat_model = env::var("OLLAMA_CHAT_MODEL") + .unwrap_or_else(|_| "llama3.2".to_string()); + format!("locally with Ollama (nomic-embed-text for embeddings, {} for chat)", chat_model) + } else { + "with OpenAI".to_string() + }; + ServerInfo { - protocol_version: ProtocolVersion::V_2024_11_05, // Use latest known version + protocol_version: ProtocolVersion::V_2024_11_05, capabilities, server_info: Implementation { name: "rust-docs-mcp-server".to_string(), version: env!("CARGO_PKG_VERSION").to_string(), }, - // Provide instructions based on the specific crate instructions: Some(format!( "This server provides tools to query documentation for the '{}' crate. \ Use the 'query_rust_docs' tool with a specific question to get information \ - about its API, usage, and examples, derived from its official documentation.", - self.crate_name + about its API, usage, and examples, derived from its official documentation. \ + Running {}.", + self.crate_name, model_info )), } } - // --- Placeholder Implementations for other ServerHandler methods --- - // Implement these properly if resource/prompt features are added later. 
async fn list_resources( &self, _request: PaginatedRequestParam, _context: RequestContext<RoleServer>, ) -> Result<ListResourcesResult, McpError> { - // Example: Return the crate name as a resource Ok(ListResourcesResult { resources: vec![ self._create_resource_text(&format!("crate://{}", self.crate_name), "crate_name"), @@ -338,7 +419,7 @@ impl ServerHandler for RustDocsServer { if request.uri == expected_uri { Ok(ReadResourceResult { contents: vec![ResourceContents::text( - self.crate_name.as_str(), // Explicitly get &str from Arc + self.crate_name.as_str(), &request.uri, )], }) @@ -357,7 +438,7 @@ impl ServerHandler for RustDocsServer { ) -> Result<ListPromptsResult, McpError> { Ok(ListPromptsResult { next_cursor: None, - prompts: Vec::new(), // No prompts defined yet + prompts: Vec::new(), }) } @@ -367,7 +448,6 @@ impl ServerHandler for RustDocsServer { _context: RequestContext<RoleServer>, ) -> Result<GetPromptResult, McpError> { Err(McpError::invalid_params( - // Or prompt_not_found if that exists format!("Prompt not found: {}", request.name), None, )) @@ -380,7 +460,7 @@ impl ServerHandler for RustDocsServer { ) -> Result<ListResourceTemplatesResult, McpError> { Ok(ListResourceTemplatesResult { next_cursor: None, - resource_templates: Vec::new(), // No templates defined yet + resource_templates: Vec::new(), }) } -} +} \ No newline at end of file