-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinference.rs
149 lines (130 loc) · 4.62 KB
/
inference.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
use crate::openapi::apis::inference_api;
use crate::openapi::models::{EmbedRequest, EmbedRequestInputsInner};
use crate::pinecone::PineconeClient;
use crate::utils::errors::PineconeError;
use crate::models::{EmbedRequestParameters, EmbeddingsList};
impl PineconeClient {
    /// Generate embeddings for input data.
    ///
    /// Each input string is wrapped in an `EmbedRequestInputsInner` and sent, together
    /// with the model name and optional parameters, through `inference_api::embed`.
    ///
    /// ### Arguments
    /// * `model: &str` - The model to use for embedding.
    /// * `parameters: Option<EmbedRequestParameters>` - Model-specific parameters.
    /// * `inputs: &[&str]` - The input data to embed. A `&Vec<&str>` is accepted as well
    ///   (it coerces to a slice), so existing callers keep compiling.
    ///
    /// ### Return
    /// * `Result<EmbeddingsList, PineconeError>`
    ///
    /// ### Errors
    /// Any error returned by the underlying API call is converted into a `PineconeError`.
    ///
    /// ### Example
    /// ```no_run
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), pinecone_sdk::utils::errors::PineconeError> {
    ///
    /// let pinecone = pinecone_sdk::pinecone::default_client()?;
    /// let response = pinecone.embed("multilingual-e5-large", None, &["Hello, world!"]).await?;
    ///
    /// # Ok(())
    /// # }
    /// ```
    pub async fn embed(
        &self,
        model: &str,
        parameters: Option<EmbedRequestParameters>,
        inputs: &[&str],
    ) -> Result<EmbeddingsList, PineconeError> {
        let request = EmbedRequest {
            model: model.to_string(),
            // The generated request model stores its parameters boxed.
            parameters: parameters.map(Box::new),
            inputs: inputs
                .iter()
                .map(|&text| EmbedRequestInputsInner {
                    text: Some(text.to_string()),
                })
                .collect(),
        };

        let res = inference_api::embed(&self.openapi_config, Some(request))
            .await
            .map_err(PineconeError::from)?;

        // Convert the generated response model into the SDK-facing `EmbeddingsList`.
        Ok(res.into())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pinecone::PineconeClientConfig;
    use httpmock::prelude::*;

    /// Canned body for a successful `/embed` response: one embedding and a usage record.
    const EMBED_OK_BODY: &str = r#"
{
"model": "multilingual-e5-large",
"data": [
{"values": [0.01849365234375, -0.003767013549804688, -0.037261962890625, 0.0222930908203125]}
],
"usage": {"total_tokens": 1632}
}
"#;

    /// Canned body for a 400 response rejecting an invalid `input_type` parameter.
    const EMBED_INVALID_ARGUMENT_BODY: &str = r#"
{
"error": {
"code": "INVALID_ARGUMENT",
"message": "Invalid parameter value input_type='bad-parameter' for model 'multilingual-e5-large', must be one of [query, passage]"
},
"status": 400
}
"#;

    /// Builds a client whose control plane host points at the given mock server.
    fn client_for(server: &MockServer) -> PineconeClient {
        let config = PineconeClientConfig {
            control_plane_host: Some(server.base_url()),
            ..Default::default()
        };
        config.client().expect("Failed to create Pinecone instance")
    }

    /// A 200 response is decoded into an `EmbeddingsList` with the expected fields.
    #[tokio::test]
    async fn test_embed() -> Result<(), PineconeError> {
        let server = MockServer::start();
        let mock = server.mock(|when, then| {
            when.method(POST).path("/embed");
            then.status(200)
                .header("content-type", "application/json")
                .body(EMBED_OK_BODY);
        });

        let pinecone = client_for(&server);
        let response = pinecone
            .embed("multilingual-e5-large", None, &vec!["Hello, world!"])
            .await
            .expect("Failed to embed");

        // The request must actually have reached the mocked endpoint.
        mock.assert();

        assert_eq!(response.model, "multilingual-e5-large");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.usage.total_tokens, 1632);

        Ok(())
    }

    /// A 400 response from the service surfaces as an `Err` from `embed`.
    #[tokio::test]
    async fn test_embed_invalid_arguments() -> Result<(), PineconeError> {
        let server = MockServer::start();
        let mock = server.mock(|when, then| {
            when.method(POST).path("/embed");
            then.status(400)
                .header("content-type", "application/json")
                .body(EMBED_INVALID_ARGUMENT_BODY);
        });

        let pinecone = client_for(&server);
        let parameters = EmbedRequestParameters {
            input_type: Some("bad-parameter".to_string()),
            truncate: Some("bad-parameter".to_string()),
        };

        let _ = pinecone
            .embed(
                "multilingual-e5-large",
                Some(parameters),
                &vec!["Hello, world!"],
            )
            .await
            .expect_err("Expected to fail embedding with invalid arguments");

        // The failure must come from the mocked endpoint, not from a request that never left.
        mock.assert();

        Ok(())
    }
}