Skip to content

Commit c70a393

Browse files
authored
Merge pull request #2 from pashadia/main
Configure models and API base from environment
2 parents 4a8d7cf + e2294df commit c70a393

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

src/main.rs

+11-8
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::{
1111
error::ServerError,
1212
server::RustDocsServer, // Import the updated RustDocsServer
1313
};
14-
use async_openai::Client as OpenAIClient;
14+
use async_openai::{Client as OpenAIClient, config::OpenAIConfig};
1515
use bincode::config;
1616
use cargo::core::PackageIdSpec;
1717
use clap::Parser; // Import clap Parser
@@ -182,7 +182,12 @@ async fn main() -> Result<(), ServerError> {
182182
let mut documents_for_server: Vec<Document> = loaded_documents_from_cache.unwrap_or_default();
183183

184184
// --- Initialize OpenAI Client (needed for question embedding even if cache hit) ---
185-
let openai_client = OpenAIClient::new();
185+
let openai_client = if let Ok(api_base) = env::var("OPENAI_API_BASE") {
186+
let config = OpenAIConfig::new().with_api_base(api_base);
187+
OpenAIClient::with_config(config)
188+
} else {
189+
OpenAIClient::new()
190+
};
186191
OPENAI_CLIENT
187192
.set(openai_client.clone()) // Clone the client for the OnceCell
188193
.expect("Failed to set OpenAI client");
@@ -209,12 +214,10 @@ async fn main() -> Result<(), ServerError> {
209214
documents_for_server = loaded_documents.clone();
210215

211216
eprintln!("Generating embeddings...");
212-
let (generated_embeddings, total_tokens) = generate_embeddings(
213-
&openai_client,
214-
&loaded_documents,
215-
"text-embedding-3-small",
216-
)
217-
.await?;
217+
let embedding_model: String = env::var("EMBEDDING_MODEL")
218+
.unwrap_or_else(|_| "text-embedding-3-small".to_string());
219+
let (generated_embeddings, total_tokens) =
220+
generate_embeddings(&openai_client, &loaded_documents, &embedding_model).await?;
218221

219222
let cost_per_million = 0.02;
220223
let estimated_cost = (total_tokens as f64 / 1_000_000.0) * cost_per_million;

src/server.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,10 @@ impl RustDocsServer {
175175
.get()
176176
.ok_or_else(|| McpError::internal_error("OpenAI client not initialized", None))?;
177177

178+
let embedding_model: String =
179+
env::var("EMBEDDING_MODEL").unwrap_or_else(|_| "text-embedding-3-small".to_string());
178180
let question_embedding_request = CreateEmbeddingRequestArgs::default()
179-
.model("text-embedding-3-small")
181+
.model(embedding_model)
180182
.input(question.to_string())
181183
.build()
182184
.map_err(|e| {
@@ -223,8 +225,10 @@ impl RustDocsServer {
223225
doc.content, question
224226
);
225227

228+
let llm_model: String = env::var("LLM_MODEL")
229+
.unwrap_or_else(|_| "gpt-4o-mini-2024-07-18".to_string());
226230
let chat_request = CreateChatCompletionRequestArgs::default()
227-
.model("gpt-4o-mini-2024-07-18")
231+
.model(llm_model)
228232
.messages(vec![
229233
ChatCompletionRequestSystemMessageArgs::default()
230234
.content(system_prompt)
@@ -379,5 +383,4 @@ impl ServerHandler for RustDocsServer {
379383
resource_templates: Vec::new(), // No templates defined yet
380384
})
381385
}
382-
383386
}

0 commit comments

Comments
 (0)