Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
use crate::error::Result;
use crate::provider::{AnyLLMProvider, CompletionStream, ProviderConfig};
use crate::types::{
ChatCompletion, CompletionParams, Message, ReasoningEffort, StopSequence, Tool, ToolChoice,
ChatCompletion, CompletionParams, Message, ReasoningEffort, RerankParams, RerankResponse,
StopSequence, Tool, ToolChoice,
};
use crate::Provider;
use serde_json::Value;
Expand Down Expand Up @@ -308,6 +309,66 @@ pub async fn completion_stream<P: Provider>(
provider.completion_stream(params).await
}

/// Options for a rerank request.
#[derive(Debug, Clone, Default)]
pub struct RerankOptions {
/// API key (if not set, uses environment variable).
pub api_key: Option<String>,

/// API base URL (for custom endpoints/proxies).
pub api_base: Option<String>,

/// Maximum number of results to return.
pub top_n: Option<u32>,

/// Maximum tokens per document for truncation.
pub max_tokens_per_doc: Option<u32>,

/// User identifier for abuse detection.
pub user: Option<String>,
}

impl From<RerankOptions> for ProviderConfig {
fn from(options: RerankOptions) -> Self {
ProviderConfig {
api_key: options.api_key,
api_base: options.api_base,
extra: Default::default(),
}
}
}

/// Rerank documents by relevance to a query.
///
/// # Arguments
/// * `model` - Model identifier (e.g., "cohere:rerank-v3.5")
/// * `query` - The search query
/// * `documents` - Documents to rerank
/// * `options` - Additional options (API key, base URL, top_n, etc.)
///
/// # Returns
/// A `RerankResponse` with results sorted by `relevance_score` descending.
///
/// # Errors
/// Returns `AnyLLMError` if the provider does not support reranking or the request fails.
pub async fn rerank<P: Provider>(
model: &str,
query: &str,
documents: Vec<String>,
options: RerankOptions,
) -> Result<RerankResponse> {
let provider = AnyLLMProvider::<P>::from_config(options.clone().into())?;
let params = RerankParams {
model_id: model.to_string(),
query: query.to_string(),
documents,
top_n: options.top_n,
max_tokens_per_doc: options.max_tokens_per_doc,
user: options.user,
};
provider.rerank(params).await
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ pub mod providers;
pub mod types;

// Re-export main types for convenience
pub use api::{completion, completion_stream, CompletionOptions};
pub use api::{completion, completion_stream, rerank, CompletionOptions, RerankOptions};
pub use error::{AnyLLMError, Result};
pub use provider::{Provider, ProviderConfig};
pub use types::{
Expand All @@ -146,6 +146,6 @@ pub use types::{
ChunkAccumulator, ChunkChoice, CompletionParams, CompletionUsage, Content, ContentPart,
CreateBatchParams, Function, ImageUrl, ListBatchesOptions, Message, ModerationContentPart,
ModerationImageUrl, ModerationInput, ModerationParams, ModerationResponse, ModerationResult,
Reasoning, ReasoningEffort, Role, StopSequence, Tool, ToolCall, ToolCallDelta, ToolChoice,
ToolFunction,
Reasoning, ReasoningEffort, RerankMeta, RerankParams, RerankResponse, RerankResult,
RerankUsage, Role, StopSequence, Tool, ToolCall, ToolCallDelta, ToolChoice, ToolFunction,
};
31 changes: 30 additions & 1 deletion src/provider/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use futures::Future;
use crate::{
error::{AnyLLMError, Result},
types::Content,
types::{ChatCompletion, CompletionParams},
types::{ChatCompletion, CompletionParams, RerankParams, RerankResponse},
};

use super::{config::ProviderConfig, CompletionStream};
Expand All @@ -21,6 +21,7 @@ pub trait Provider: Sized + Send + Sync {
const SUPPORTS_IMAGES: bool = false;
const SUPPORTS_REASONING: bool = false;
const SUPPORTS_PDF: bool = false;
const SUPPORTS_RERANK: bool = false;

fn api_key(config: &ProviderConfig) -> Option<String> {
config
Expand Down Expand Up @@ -106,4 +107,32 @@ pub trait Provider: Sized + Send + Sync {
self.completion_stream_fn(params).await
}
}

// ── Rerank ──────────────────────────────────────────────────

/// Rerank documents by relevance to a query.
/// Providers that support reranking must override `rerank_fn`.
fn rerank_fn(
&self,
_params: RerankParams,
) -> impl Future<Output = Result<RerankResponse>> + Send {
async {
Err(AnyLLMError::invalid_request::<Self>(
"rerank not supported by this provider",
))
}
}

/// Rerank documents by relevance to a query (with validation).
fn rerank(&self, params: RerankParams) -> impl Future<Output = Result<RerankResponse>> + Send {
async {
if !Self::SUPPORTS_RERANK {
return Err(AnyLLMError::unsupported_parameter::<Self>(
"rerank",
"this provider does not support reranking",
));
}
self.rerank_fn(params).await
}
}
}
7 changes: 6 additions & 1 deletion src/provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use futures::Stream;

use crate::{
error::Result,
types::{ChatCompletion, ChatCompletionChunk, CompletionParams},
types::{ChatCompletion, ChatCompletionChunk, CompletionParams, RerankParams, RerankResponse},
};

mod config;
Expand Down Expand Up @@ -35,4 +35,9 @@ impl<P: Provider> AnyLLMProvider<P> {
pub async fn completion_stream(&self, params: CompletionParams) -> Result<CompletionStream> {
self.0.completion_stream(params).await
}

/// Rerank documents by relevance to a query.
pub async fn rerank(&self, params: RerankParams) -> Result<RerankResponse> {
self.0.rerank(params).await
}
}
25 changes: 24 additions & 1 deletion src/providers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::error::{AnyLLMError, Result};
use crate::provider::{CompletionStream, Provider, ProviderConfig};
use crate::types::{
Batch, BatchResult, ChatCompletion, CompletionParams, CreateBatchParams, ListBatchesOptions,
ModerationParams, ModerationResponse,
ModerationParams, ModerationResponse, RerankParams, RerankResponse,
};

mod models;
Expand Down Expand Up @@ -202,6 +202,7 @@ impl Provider for Gateway {
const SUPPORTS_IMAGES: bool = true;
const SUPPORTS_REASONING: bool = true;
const SUPPORTS_PDF: bool = true;
const SUPPORTS_RERANK: bool = true;

fn from_config(config: ProviderConfig) -> Result<Self> {
let api_base = config
Expand Down Expand Up @@ -274,6 +275,28 @@ impl Provider for Gateway {

GatewayStream::new(es, model).try_into()
}

async fn rerank_fn(&self, params: RerankParams) -> Result<RerankResponse> {
let body = models::rerank::GatewayRerankRequest::from(params);

let response = self
.client
.post(format!("{}/v1/rerank", self.api_base))
.json(&body)
.send()
.await
.map_err(AnyLLMError::from)?;

let status = response.status().as_u16();
if status != 200 {
return Err(convert_error(response).await);
}

response
.json::<RerankResponse>()
.await
.map_err(AnyLLMError::from)
}
}

/// Resolve auth mode and build the appropriate HTTP headers.
Expand Down
1 change: 1 addition & 0 deletions src/providers/gateway/models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod request;
pub mod rerank;
pub mod response;
pub mod stream;
68 changes: 68 additions & 0 deletions src/providers/gateway/models/rerank.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use serde::Serialize;

use crate::types::RerankParams;

/// Gateway wire format for a rerank request.
#[derive(Debug, Serialize)]
pub struct GatewayRerankRequest {
pub model: String,
pub query: String,
pub documents: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens_per_doc: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}

impl From<RerankParams> for GatewayRerankRequest {
fn from(params: RerankParams) -> Self {
Self {
model: params.model_id,
query: params.query,
documents: params.documents,
top_n: params.top_n,
max_tokens_per_doc: params.max_tokens_per_doc,
user: params.user,
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_from_rerank_params() {
let params = RerankParams {
model_id: "cohere:rerank-v3.5".to_string(),
query: "test query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string()],
top_n: Some(2),
max_tokens_per_doc: None,
user: None,
};
let req = GatewayRerankRequest::from(params);
assert_eq!(req.model, "cohere:rerank-v3.5");
assert_eq!(req.query, "test query");
assert_eq!(req.documents.len(), 2);
assert_eq!(req.top_n, Some(2));
}

#[test]
fn test_serialization_skips_none() {
let req = GatewayRerankRequest {
model: "test".to_string(),
query: "q".to_string(),
documents: vec!["d".to_string()],
top_n: None,
max_tokens_per_doc: None,
user: None,
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("top_n"));
assert!(!json.contains("max_tokens_per_doc"));
assert!(!json.contains("user"));
}
}
2 changes: 2 additions & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod chunk;
mod completion;
mod message;
mod moderation;
mod rerank;
mod tool;
mod usage;

Expand All @@ -17,5 +18,6 @@ pub use chunk::*;
pub use completion::*;
pub use message::*;
pub use moderation::*;
pub use rerank::*;
pub use tool::*;
pub use usage::*;
50 changes: 50 additions & 0 deletions src/types/rerank.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Parameters for a rerank request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankParams {
pub model_id: String,
pub query: String,
pub documents: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens_per_doc: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}

/// A single reranked document score.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RerankResult {
pub index: u32,
pub relevance_score: f64,
}

/// Provider-specific billing metadata.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RerankMeta {
#[serde(skip_serializing_if = "Option::is_none")]
pub billed_units: Option<HashMap<String, f64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tokens: Option<HashMap<String, u32>>,
}

/// Normalized token usage for a rerank request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RerankUsage {
#[serde(skip_serializing_if = "Option::is_none")]
pub total_tokens: Option<u32>,
}

/// Normalized rerank response, provider-agnostic.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RerankResponse {
pub id: String,
pub results: Vec<RerankResult>,
#[serde(skip_serializing_if = "Option::is_none")]
pub meta: Option<RerankMeta>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<RerankUsage>,
}
Loading
Loading