diff --git a/src/api.rs b/src/api.rs index 9bf9648..7a46657 100644 --- a/src/api.rs +++ b/src/api.rs @@ -6,7 +6,8 @@ use crate::error::Result; use crate::provider::{AnyLLMProvider, CompletionStream, ProviderConfig}; use crate::types::{ - ChatCompletion, CompletionParams, Message, ReasoningEffort, StopSequence, Tool, ToolChoice, + ChatCompletion, CompletionParams, Message, ReasoningEffort, RerankParams, RerankResponse, + StopSequence, Tool, ToolChoice, }; use crate::Provider; use serde_json::Value; @@ -308,6 +309,66 @@ pub async fn completion_stream( provider.completion_stream(params).await } +/// Options for a rerank request. +#[derive(Debug, Clone, Default)] +pub struct RerankOptions { + /// API key (if not set, uses environment variable). + pub api_key: Option, + + /// API base URL (for custom endpoints/proxies). + pub api_base: Option, + + /// Maximum number of results to return. + pub top_n: Option, + + /// Maximum tokens per document for truncation. + pub max_tokens_per_doc: Option, + + /// User identifier for abuse detection. + pub user: Option, +} + +impl From for ProviderConfig { + fn from(options: RerankOptions) -> Self { + ProviderConfig { + api_key: options.api_key, + api_base: options.api_base, + extra: Default::default(), + } + } +} + +/// Rerank documents by relevance to a query. +/// +/// # Arguments +/// * `model` - Model identifier (e.g., "cohere:rerank-v3.5") +/// * `query` - The search query +/// * `documents` - Documents to rerank +/// * `options` - Additional options (API key, base URL, top_n, etc.) +/// +/// # Returns +/// A `RerankResponse` with results sorted by `relevance_score` descending. +/// +/// # Errors +/// Returns `AnyLLMError` if the provider does not support reranking or the request fails. +pub async fn rerank( + model: &str, + query: &str, + documents: Vec, + options: RerankOptions, +) -> Result { + let provider = AnyLLMProvider::

::from_config(options.clone().into())?; + let params = RerankParams { + model_id: model.to_string(), + query: query.to_string(), + documents, + top_n: options.top_n, + max_tokens_per_doc: options.max_tokens_per_doc, + user: options.user, + }; + provider.rerank(params).await +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index f138874..04da66b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -137,7 +137,7 @@ pub mod providers; pub mod types; // Re-export main types for convenience -pub use api::{completion, completion_stream, CompletionOptions}; +pub use api::{completion, completion_stream, rerank, CompletionOptions, RerankOptions}; pub use error::{AnyLLMError, Result}; pub use provider::{Provider, ProviderConfig}; pub use types::{ @@ -146,6 +146,6 @@ pub use types::{ ChunkAccumulator, ChunkChoice, CompletionParams, CompletionUsage, Content, ContentPart, CreateBatchParams, Function, ImageUrl, ListBatchesOptions, Message, ModerationContentPart, ModerationImageUrl, ModerationInput, ModerationParams, ModerationResponse, ModerationResult, - Reasoning, ReasoningEffort, Role, StopSequence, Tool, ToolCall, ToolCallDelta, ToolChoice, - ToolFunction, + Reasoning, ReasoningEffort, RerankMeta, RerankParams, RerankResponse, RerankResult, + RerankUsage, Role, StopSequence, Tool, ToolCall, ToolCallDelta, ToolChoice, ToolFunction, }; diff --git a/src/provider/interface.rs b/src/provider/interface.rs index bf3c90c..e72c95c 100644 --- a/src/provider/interface.rs +++ b/src/provider/interface.rs @@ -3,7 +3,7 @@ use futures::Future; use crate::{ error::{AnyLLMError, Result}, types::Content, - types::{ChatCompletion, CompletionParams}, + types::{ChatCompletion, CompletionParams, RerankParams, RerankResponse}, }; use super::{config::ProviderConfig, CompletionStream}; @@ -21,6 +21,7 @@ pub trait Provider: Sized + Send + Sync { const SUPPORTS_IMAGES: bool = false; const SUPPORTS_REASONING: bool = false; const SUPPORTS_PDF: bool = false; + const SUPPORTS_RERANK: bool = false; fn api_key(config: &ProviderConfig) -> Option { config @@ -106,4 +107,32 @@ pub trait Provider: Sized + Send + Sync { self.completion_stream_fn(params).await } } + + // ── Rerank ────────────────────────────────────────────────── + + /// Rerank documents by relevance to a query. + /// Providers that support reranking must override `rerank_fn`. + fn rerank_fn( + &self, + _params: RerankParams, + ) -> impl Future> + Send { + async { + Err(AnyLLMError::invalid_request::( + "rerank not supported by this provider", + )) + } + } + + /// Rerank documents by relevance to a query (with validation). + fn rerank(&self, params: RerankParams) -> impl Future> + Send { + async { + if !Self::SUPPORTS_RERANK { + return Err(AnyLLMError::unsupported_parameter::( + "rerank", + "this provider does not support reranking", + )); + } + self.rerank_fn(params).await + } + } } diff --git a/src/provider/mod.rs b/src/provider/mod.rs index f6a09bf..ec51269 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -6,7 +6,7 @@ use futures::Stream; use crate::{ error::Result, - types::{ChatCompletion, ChatCompletionChunk, CompletionParams}, + types::{ChatCompletion, ChatCompletionChunk, CompletionParams, RerankParams, RerankResponse}, }; mod config; @@ -35,4 +35,9 @@ impl AnyLLMProvider

{ pub async fn completion_stream(&self, params: CompletionParams) -> Result { self.0.completion_stream(params).await } + + /// Rerank documents by relevance to a query. + pub async fn rerank(&self, params: RerankParams) -> Result { + self.0.rerank(params).await + } } diff --git a/src/providers/gateway/mod.rs b/src/providers/gateway/mod.rs index 467b034..c61bfcd 100644 --- a/src/providers/gateway/mod.rs +++ b/src/providers/gateway/mod.rs @@ -22,7 +22,7 @@ use crate::error::{AnyLLMError, Result}; use crate::provider::{CompletionStream, Provider, ProviderConfig}; use crate::types::{ Batch, BatchResult, ChatCompletion, CompletionParams, CreateBatchParams, ListBatchesOptions, - ModerationParams, ModerationResponse, + ModerationParams, ModerationResponse, RerankParams, RerankResponse, }; mod models; @@ -202,6 +202,7 @@ impl Provider for Gateway { const SUPPORTS_IMAGES: bool = true; const SUPPORTS_REASONING: bool = true; const SUPPORTS_PDF: bool = true; + const SUPPORTS_RERANK: bool = true; fn from_config(config: ProviderConfig) -> Result { let api_base = config @@ -274,6 +275,28 @@ impl Provider for Gateway { GatewayStream::new(es, model).try_into() } + + async fn rerank_fn(&self, params: RerankParams) -> Result { + let body = models::rerank::GatewayRerankRequest::from(params); + + let response = self + .client + .post(format!("{}/v1/rerank", self.api_base)) + .json(&body) + .send() + .await + .map_err(AnyLLMError::from)?; + + let status = response.status().as_u16(); + if status != 200 { + return Err(convert_error(response).await); + } + + response + .json::() + .await + .map_err(AnyLLMError::from) + } } /// Resolve auth mode and build the appropriate HTTP headers. diff --git a/src/providers/gateway/models/mod.rs b/src/providers/gateway/models/mod.rs index 1576bda..2563916 100644 --- a/src/providers/gateway/models/mod.rs +++ b/src/providers/gateway/models/mod.rs @@ -1,3 +1,4 @@ pub mod request; +pub mod rerank; pub mod response; pub mod stream; diff --git a/src/providers/gateway/models/rerank.rs b/src/providers/gateway/models/rerank.rs new file mode 100644 index 0000000..3600977 --- /dev/null +++ b/src/providers/gateway/models/rerank.rs @@ -0,0 +1,68 @@ +use serde::Serialize; + +use crate::types::RerankParams; + +/// Gateway wire format for a rerank request. +#[derive(Debug, Serialize)] +pub struct GatewayRerankRequest { + pub model: String, + pub query: String, + pub documents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens_per_doc: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +impl From for GatewayRerankRequest { + fn from(params: RerankParams) -> Self { + Self { + model: params.model_id, + query: params.query, + documents: params.documents, + top_n: params.top_n, + max_tokens_per_doc: params.max_tokens_per_doc, + user: params.user, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_rerank_params() { + let params = RerankParams { + model_id: "cohere:rerank-v3.5".to_string(), + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + top_n: Some(2), + max_tokens_per_doc: None, + user: None, + }; + let req = GatewayRerankRequest::from(params); + assert_eq!(req.model, "cohere:rerank-v3.5"); + assert_eq!(req.query, "test query"); + assert_eq!(req.documents.len(), 2); + assert_eq!(req.top_n, Some(2)); + } + + #[test] + fn test_serialization_skips_none() { + let req = GatewayRerankRequest { + model: "test".to_string(), + query: "q".to_string(), + documents: vec!["d".to_string()], + top_n: None, + max_tokens_per_doc: None, + user: None, + }; + let json = serde_json::to_string(&req).unwrap(); + assert!(!json.contains("top_n")); + assert!(!json.contains("max_tokens_per_doc")); + assert!(!json.contains("user")); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 459e3fd..99f93e0 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -9,6 +9,7 @@ mod chunk; mod completion; mod message; mod moderation; +mod rerank; mod tool; mod usage; @@ -17,5 +18,6 @@ pub use chunk::*; pub use completion::*; pub use message::*; pub use moderation::*; +pub use rerank::*; pub use tool::*; pub use usage::*; diff --git a/src/types/rerank.rs b/src/types/rerank.rs new file mode 100644 index 0000000..41447e2 --- /dev/null +++ b/src/types/rerank.rs @@ -0,0 +1,50 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Parameters for a rerank request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RerankParams { + pub model_id: String, + pub query: String, + pub documents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens_per_doc: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +/// A single reranked document score. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct RerankResult { + pub index: u32, + pub relevance_score: f64, +} + +/// Provider-specific billing metadata. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct RerankMeta { + #[serde(skip_serializing_if = "Option::is_none")] + pub billed_units: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens: Option>, +} + +/// Normalized token usage for a rerank request. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RerankUsage { + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tokens: Option, +} + +/// Normalized rerank response, provider-agnostic. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct RerankResponse { + pub id: String, + pub results: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub meta: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} diff --git a/tests/test_gateway.rs b/tests/test_gateway.rs index c474233..3d32a01 100644 --- a/tests/test_gateway.rs +++ b/tests/test_gateway.rs @@ -1,7 +1,8 @@ use any_llm::providers::Gateway; +use any_llm::types::RerankParams; use any_llm::{ - AnyLLMError, Batch, BatchRequestItem, BatchResult, BatchStatus, CompletionOptions, - CreateBatchParams, ListBatchesOptions, Message, Provider, ProviderConfig, + rerank, AnyLLMError, Batch, BatchRequestItem, BatchResult, BatchStatus, CompletionOptions, + CreateBatchParams, ListBatchesOptions, Message, Provider, ProviderConfig, RerankOptions, }; use futures::StreamExt; use wiremock::matchers::{header, method, path, query_param}; @@ -538,6 +539,168 @@ async fn live_gateway_completion() { println!("Live response: {content}"); } +// --------------------------------------------------------------------------- +// Rerank tests (wiremock) +// --------------------------------------------------------------------------- + +fn rerank_response_json() -> String { + r#"{ + "id": "rerank-test-123", + "results": [ + {"index": 0, "relevance_score": 0.95}, + {"index": 2, "relevance_score": 0.80}, + {"index": 1, "relevance_score": 0.30} + ], + "meta": { + "billed_units": {"search_units": 1.0}, + "tokens": {"input_tokens": 100} + }, + "usage": {"total_tokens": 100} + }"# + .to_string() +} + +#[tokio::test] +async fn test_gateway_rerank() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/rerank")) + .respond_with(ResponseTemplate::new(200).set_body_string(rerank_response_json())) + .mount(&mock_server) + .await; + + let gateway = Gateway::from_config(ProviderConfig { + api_base: Some(mock_server.uri()), + api_key: Some("test-key".to_string()), + ..Default::default() + }) + .unwrap(); + + let result = gateway + .rerank(RerankParams { + model_id: "cohere:rerank-v3.5".to_string(), + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()], + top_n: Some(3), + max_tokens_per_doc: None, + user: None, + }) + .await + .unwrap(); + + assert_eq!(result.id, "rerank-test-123"); + assert_eq!(result.results.len(), 3); + assert!((result.results[0].relevance_score - 0.95).abs() < f64::EPSILON); + assert_eq!(result.usage.unwrap().total_tokens, Some(100)); +} + +#[tokio::test] +async fn test_gateway_rerank_401_error() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/rerank")) + .respond_with( + ResponseTemplate::new(401).set_body_string(r#"{"error": {"message": "Unauthorized"}}"#), + ) + .mount(&mock_server) + .await; + + let gateway = Gateway::from_config(ProviderConfig { + api_base: Some(mock_server.uri()), + api_key: Some("bad-key".to_string()), + ..Default::default() + }) + .unwrap(); + + let err = gateway + .rerank(RerankParams { + model_id: "cohere:rerank-v3.5".to_string(), + query: "test".to_string(), + documents: vec!["doc".to_string()], + top_n: None, + max_tokens_per_doc: None, + user: None, + }) + .await + .unwrap_err(); + + assert!(matches!(err, AnyLLMError::Authentication { .. })); +} + +#[tokio::test] +async fn test_gateway_rerank_429_error() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/rerank")) + .respond_with( + ResponseTemplate::new(429) + .set_body_string(r#"{"error": {"message": "Rate limited"}}"#) + .append_header("retry-after", "60"), + ) + .mount(&mock_server) + .await; + + let gateway = Gateway::from_config(ProviderConfig { + api_base: Some(mock_server.uri()), + api_key: Some("key".to_string()), + ..Default::default() + }) + .unwrap(); + + let err = gateway + .rerank(RerankParams { + model_id: "cohere:rerank-v3.5".to_string(), + query: "test".to_string(), + documents: vec!["doc".to_string()], + top_n: None, + max_tokens_per_doc: None, + user: None, + }) + .await + .unwrap_err(); + + assert!(matches!(err, AnyLLMError::RateLimit { .. })); +} + +#[tokio::test] +async fn test_rerank_api_function() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/rerank")) + .respond_with(ResponseTemplate::new(200).set_body_string(rerank_response_json())) + .mount(&mock_server) + .await; + + let result = rerank::( + "cohere:rerank-v3.5", + "test query", + vec!["doc1".to_string(), "doc2".to_string()], + RerankOptions { + api_base: Some(mock_server.uri()), + api_key: Some("test-key".to_string()), + top_n: Some(2), + ..Default::default() + }, + ) + .await + .unwrap(); + + assert_eq!(result.id, "rerank-test-123"); +} + +#[test] +fn test_gateway_supports_rerank() { + const { assert!(Gateway::SUPPORTS_RERANK) }; +} + +// --------------------------------------------------------------------------- +// Live integration tests (require a running gateway) +// --------------------------------------------------------------------------- + #[tokio::test] #[ignore = "requires a running gateway server"] async fn live_gateway_streaming() { diff --git a/tests/test_rerank.rs b/tests/test_rerank.rs new file mode 100644 index 0000000..060700f --- /dev/null +++ b/tests/test_rerank.rs @@ -0,0 +1,74 @@ +use any_llm::types::{RerankMeta, RerankParams, RerankResponse, RerankResult, RerankUsage}; +use pretty_assertions::assert_eq; +use std::collections::HashMap; + +#[test] +fn test_rerank_response_serde_roundtrip() { + let response = RerankResponse { + id: "rerank-123".to_string(), + results: vec![ + RerankResult { + index: 0, + relevance_score: 0.95, + }, + RerankResult { + index: 2, + relevance_score: 0.80, + }, + ], + meta: Some(RerankMeta { + billed_units: Some(HashMap::from([("search_units".to_string(), 1.0)])), + tokens: Some(HashMap::from([("input_tokens".to_string(), 100)])), + }), + usage: Some(RerankUsage { + total_tokens: Some(100), + }), + }; + + let json = serde_json::to_string(&response).unwrap(); + let deserialized: RerankResponse = serde_json::from_str(&json).unwrap(); + assert_eq!(response, deserialized); +} + +#[test] +fn test_rerank_response_minimal() { + let json = r#"{"id": "r-1", "results": []}"#; + let response: RerankResponse = serde_json::from_str(json).unwrap(); + assert_eq!(response.id, "r-1"); + assert!(response.results.is_empty()); + assert!(response.meta.is_none()); + assert!(response.usage.is_none()); +} + +#[test] +fn test_rerank_params_serialization() { + let params = RerankParams { + model_id: "cohere:rerank-v3.5".to_string(), + query: "test query".to_string(), + documents: vec!["doc1".to_string(), "doc2".to_string()], + top_n: Some(2), + max_tokens_per_doc: None, + user: None, + }; + let json = serde_json::to_string(¶ms).unwrap(); + assert!(json.contains("\"model_id\"")); + assert!(json.contains("\"top_n\":2")); + assert!(!json.contains("max_tokens_per_doc")); + assert!(!json.contains("\"user\"")); +} + +#[test] +fn test_rerank_result_ordering() { + let json = r#"{ + "id": "r-1", + "results": [ + {"index": 2, "relevance_score": 0.30}, + {"index": 0, "relevance_score": 0.95}, + {"index": 1, "relevance_score": 0.80} + ] + }"#; + let response: RerankResponse = serde_json::from_str(json).unwrap(); + // The SDK deserializes as-is. Sorting is the server's responsibility. + assert_eq!(response.results[0].index, 2); + assert_eq!(response.results[1].index, 0); +}