Skip to content

Commit 02993ad

Browse files
committed
feat: Add support for conversational search
1 parent 5e84bdf commit 02993ad

File tree

5 files changed

+403
-1
lines changed

5 files changed

+403
-1
lines changed

src/chats.rs

Lines changed: 384 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,384 @@
1+
use std::collections::BTreeMap;
2+
3+
use serde::{Deserialize, Serialize};
4+
use serde_json::Value;
5+
6+
use crate::{
7+
client::Client,
8+
errors::Error,
9+
request::{HttpClient, Method},
10+
};
11+
12+
/// Representation of a chat workspace.
13+
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
14+
#[serde(rename_all = "camelCase")]
15+
pub struct ChatWorkspace {
16+
pub uid: String,
17+
}
18+
19+
/// Paginated chat workspace results.
20+
#[derive(Debug, Clone, Deserialize, Serialize)]
21+
#[serde(rename_all = "camelCase")]
22+
pub struct ChatWorkspacesResults {
23+
pub results: Vec<ChatWorkspace>,
24+
pub offset: u32,
25+
pub limit: u32,
26+
pub total: u32,
27+
}
28+
29+
/// Chat workspace prompts payload.
30+
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
31+
#[serde(rename_all = "camelCase")]
32+
pub struct ChatPrompts {
33+
#[serde(skip_serializing_if = "Option::is_none")]
34+
pub system: Option<String>,
35+
#[serde(skip_serializing_if = "Option::is_none")]
36+
pub search_description: Option<String>,
37+
#[serde(rename = "searchQParam", skip_serializing_if = "Option::is_none")]
38+
pub search_q_param: Option<String>,
39+
#[serde(
40+
rename = "searchIndexUidParam",
41+
skip_serializing_if = "Option::is_none"
42+
)]
43+
pub search_index_uid_param: Option<String>,
44+
/// Any additional provider-specific prompt values.
45+
#[serde(default, flatten, skip_serializing_if = "BTreeMap::is_empty")]
46+
pub extra: BTreeMap<String, String>,
47+
}
48+
49+
impl ChatPrompts {
50+
#[must_use]
51+
pub fn new() -> Self {
52+
Self::default()
53+
}
54+
55+
pub fn set_system(&mut self, value: impl Into<String>) -> &mut Self {
56+
self.system = Some(value.into());
57+
self
58+
}
59+
60+
pub fn set_search_description(&mut self, value: impl Into<String>) -> &mut Self {
61+
self.search_description = Some(value.into());
62+
self
63+
}
64+
65+
pub fn set_search_q_param(&mut self, value: impl Into<String>) -> &mut Self {
66+
self.search_q_param = Some(value.into());
67+
self
68+
}
69+
70+
pub fn set_search_index_uid_param(&mut self, value: impl Into<String>) -> &mut Self {
71+
self.search_index_uid_param = Some(value.into());
72+
self
73+
}
74+
75+
pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
76+
self.extra.insert(key.into(), value.into());
77+
self
78+
}
79+
}
80+
81+
/// Chat workspace settings payload.
82+
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
83+
#[serde(rename_all = "camelCase")]
84+
pub struct ChatWorkspaceSettings {
85+
#[serde(skip_serializing_if = "Option::is_none")]
86+
pub source: Option<String>,
87+
#[serde(skip_serializing_if = "Option::is_none")]
88+
pub org_id: Option<String>,
89+
#[serde(skip_serializing_if = "Option::is_none")]
90+
pub project_id: Option<String>,
91+
#[serde(skip_serializing_if = "Option::is_none")]
92+
pub api_version: Option<String>,
93+
#[serde(skip_serializing_if = "Option::is_none")]
94+
pub deployment_id: Option<String>,
95+
#[serde(skip_serializing_if = "Option::is_none")]
96+
pub base_url: Option<String>,
97+
#[serde(skip_serializing_if = "Option::is_none")]
98+
pub api_key: Option<String>,
99+
#[serde(skip_serializing_if = "Option::is_none")]
100+
pub prompts: Option<ChatPrompts>,
101+
}
102+
103+
impl ChatWorkspaceSettings {
104+
#[must_use]
105+
pub fn new() -> Self {
106+
Self::default()
107+
}
108+
109+
pub fn set_source(&mut self, source: impl Into<String>) -> &mut Self {
110+
self.source = Some(source.into());
111+
self
112+
}
113+
114+
pub fn set_org_id(&mut self, org_id: impl Into<String>) -> &mut Self {
115+
self.org_id = Some(org_id.into());
116+
self
117+
}
118+
119+
pub fn set_project_id(&mut self, project_id: impl Into<String>) -> &mut Self {
120+
self.project_id = Some(project_id.into());
121+
self
122+
}
123+
124+
pub fn set_api_version(&mut self, api_version: impl Into<String>) -> &mut Self {
125+
self.api_version = Some(api_version.into());
126+
self
127+
}
128+
129+
pub fn set_deployment_id(&mut self, deployment_id: impl Into<String>) -> &mut Self {
130+
self.deployment_id = Some(deployment_id.into());
131+
self
132+
}
133+
134+
pub fn set_base_url(&mut self, base_url: impl Into<String>) -> &mut Self {
135+
self.base_url = Some(base_url.into());
136+
self
137+
}
138+
139+
pub fn set_api_key(&mut self, api_key: impl Into<String>) -> &mut Self {
140+
self.api_key = Some(api_key.into());
141+
self
142+
}
143+
144+
pub fn set_prompts(&mut self, prompts: ChatPrompts) -> &mut Self {
145+
self.prompts = Some(prompts);
146+
self
147+
}
148+
}
149+
150+
/// Query builder for listing chat workspaces.
151+
#[derive(Debug, Serialize)]
152+
pub struct ChatWorkspacesQuery<'a, Http: HttpClient> {
153+
#[serde(skip_serializing)]
154+
pub client: &'a Client<Http>,
155+
#[serde(skip_serializing_if = "Option::is_none")]
156+
pub offset: Option<usize>,
157+
#[serde(skip_serializing_if = "Option::is_none")]
158+
pub limit: Option<usize>,
159+
}
160+
161+
impl<'a, Http: HttpClient> ChatWorkspacesQuery<'a, Http> {
162+
#[must_use]
163+
pub fn new(client: &'a Client<Http>) -> Self {
164+
Self {
165+
client,
166+
offset: None,
167+
limit: None,
168+
}
169+
}
170+
171+
pub fn with_offset(&mut self, offset: usize) -> &mut Self {
172+
self.offset = Some(offset);
173+
self
174+
}
175+
176+
pub fn with_limit(&mut self, limit: usize) -> &mut Self {
177+
self.limit = Some(limit);
178+
self
179+
}
180+
181+
pub async fn execute(&self) -> Result<ChatWorkspacesResults, Error> {
182+
self.client.list_chat_workspaces_with(self).await
183+
}
184+
}
185+
186+
impl<Http: HttpClient> Client<Http> {
187+
/// List all chat workspaces.
188+
pub async fn list_chat_workspaces(&self) -> Result<ChatWorkspacesResults, Error> {
189+
self.http_client
190+
.request::<(), (), ChatWorkspacesResults>(
191+
&format!("{}/chats", self.host),
192+
Method::Get { query: () },
193+
200,
194+
)
195+
.await
196+
}
197+
198+
/// List chat workspaces using query parameters.
199+
pub async fn list_chat_workspaces_with(
200+
&self,
201+
query: &ChatWorkspacesQuery<'_, Http>,
202+
) -> Result<ChatWorkspacesResults, Error> {
203+
self.http_client
204+
.request::<&ChatWorkspacesQuery<'_, Http>, (), ChatWorkspacesResults>(
205+
&format!("{}/chats", self.host),
206+
Method::Get { query },
207+
200,
208+
)
209+
.await
210+
}
211+
212+
/// Retrieve a chat workspace by uid.
213+
pub async fn get_chat_workspace(&self, uid: impl AsRef<str>) -> Result<ChatWorkspace, Error> {
214+
self.http_client
215+
.request::<(), (), ChatWorkspace>(
216+
&format!("{}/chats/{}", self.host, uid.as_ref()),
217+
Method::Get { query: () },
218+
200,
219+
)
220+
.await
221+
}
222+
223+
/// Retrieve chat workspace settings.
224+
pub async fn get_chat_workspace_settings(
225+
&self,
226+
uid: impl AsRef<str>,
227+
) -> Result<ChatWorkspaceSettings, Error> {
228+
self.http_client
229+
.request::<(), (), ChatWorkspaceSettings>(
230+
&format!("{}/chats/{}/settings", self.host, uid.as_ref()),
231+
Method::Get { query: () },
232+
200,
233+
)
234+
.await
235+
}
236+
237+
/// Create or update chat workspace settings.
238+
pub async fn update_chat_workspace_settings(
239+
&self,
240+
uid: impl AsRef<str>,
241+
settings: &ChatWorkspaceSettings,
242+
) -> Result<ChatWorkspaceSettings, Error> {
243+
self.http_client
244+
.request::<(), &ChatWorkspaceSettings, ChatWorkspaceSettings>(
245+
&format!("{}/chats/{}/settings", self.host, uid.as_ref()),
246+
Method::Patch {
247+
query: (),
248+
body: settings,
249+
},
250+
200,
251+
)
252+
.await
253+
}
254+
255+
/// Reset chat workspace settings to defaults.
256+
pub async fn reset_chat_workspace_settings(
257+
&self,
258+
uid: impl AsRef<str>,
259+
) -> Result<ChatWorkspaceSettings, Error> {
260+
self.http_client
261+
.request::<(), (), ChatWorkspaceSettings>(
262+
&format!("{}/chats/{}/settings", self.host, uid.as_ref()),
263+
Method::Delete { query: () },
264+
200,
265+
)
266+
.await
267+
}
268+
}
269+
270+
#[cfg(feature = "reqwest")]
271+
impl Client<crate::reqwest::ReqwestClient> {
272+
/// Stream chat completions for a workspace.
273+
pub async fn stream_chat_completion<S: Serialize + ?Sized>(
274+
&self,
275+
uid: impl AsRef<str>,
276+
body: &S,
277+
) -> Result<reqwest::Response, Error> {
278+
use reqwest::header::{HeaderValue, ACCEPT, CONTENT_TYPE};
279+
280+
let payload = serde_json::to_vec(body).map_err(Error::ParseError)?;
281+
282+
let response = self
283+
.http_client
284+
.inner()
285+
.post(format!(
286+
"{}/chats/{}/chat/completions",
287+
self.host,
288+
uid.as_ref()
289+
))
290+
.header(ACCEPT, HeaderValue::from_static("text/event-stream"))
291+
.header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
292+
.body(payload)
293+
.send()
294+
.await?;
295+
296+
let status = response.status();
297+
if !status.is_success() {
298+
let url = response.url().to_string();
299+
let mut body = response.text().await?;
300+
if body.is_empty() {
301+
body = "null".to_string();
302+
}
303+
let err =
304+
match crate::request::parse_response::<Value>(status.as_u16(), 200, &body, url) {
305+
Ok(_) => unreachable!("parse_response succeeded on a non-successful status"),
306+
Err(err) => err,
307+
};
308+
return Err(err);
309+
}
310+
311+
Ok(response)
312+
}
313+
}
314+
315+
#[cfg(test)]
316+
mod tests {
317+
use super::*;
318+
use crate::features::ExperimentalFeatures;
319+
use meilisearch_test_macro::meilisearch_test;
320+
321+
#[meilisearch_test]
322+
async fn chat_workspace_lifecycle(client: Client, name: String) -> Result<(), Error> {
323+
let mut features = ExperimentalFeatures::new(&client);
324+
features.set_chat_completions(true);
325+
let _ = features.update().await?;
326+
327+
let workspace = format!("{name}-workspace");
328+
329+
let mut prompts = ChatPrompts::new();
330+
prompts.set_system("You are a helpful assistant.");
331+
prompts.set_search_description("Use search to fetch relevant documents.");
332+
333+
let mut settings = ChatWorkspaceSettings::new();
334+
settings
335+
.set_source("openAi")
336+
.set_api_key("sk-test")
337+
.set_prompts(prompts.clone());
338+
339+
let updated = client
340+
.update_chat_workspace_settings(&workspace, &settings)
341+
.await?;
342+
assert_eq!(updated.source.as_deref(), Some("openAi"));
343+
let updated_prompts = updated
344+
.prompts
345+
.expect("updated settings should contain prompts");
346+
assert_eq!(updated_prompts.system.as_deref(), prompts.system.as_deref());
347+
assert_eq!(
348+
updated_prompts.search_description.as_deref(),
349+
prompts.search_description.as_deref()
350+
);
351+
if let Some(masked_key) = updated.api_key.as_ref() {
352+
assert_ne!(
353+
masked_key, "sk-test",
354+
"API key should not be returned in clear text"
355+
);
356+
}
357+
358+
let workspace_info = client.get_chat_workspace(&workspace).await?;
359+
assert_eq!(workspace_info.uid, workspace);
360+
361+
let fetched_settings = client.get_chat_workspace_settings(&workspace).await?;
362+
assert_eq!(fetched_settings.source.as_deref(), Some("openAi"));
363+
let fetched_prompts = fetched_settings
364+
.prompts
365+
.expect("workspace should have prompts configured");
366+
assert_eq!(fetched_prompts.system.as_deref(), prompts.system.as_deref());
367+
assert_eq!(
368+
fetched_prompts.search_description.as_deref(),
369+
prompts.search_description.as_deref()
370+
);
371+
372+
let list = client.list_chat_workspaces().await?;
373+
assert!(list.results.iter().any(|w| w.uid == workspace));
374+
375+
let mut query = ChatWorkspacesQuery::new(&client);
376+
query.with_limit(1);
377+
let limited = query.execute().await?;
378+
assert_eq!(limited.limit, 1);
379+
380+
let _ = client.reset_chat_workspace_settings(&workspace).await?;
381+
382+
Ok(())
383+
}
384+
}

src/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1296,7 +1296,7 @@ mod tests {
12961296

12971297
use meilisearch_test_macro::meilisearch_test;
12981298

1299-
use crate::{client::*, key::Action, reqwest::qualified_version};
1299+
use crate::{key::Action, reqwest::qualified_version};
13001300

13011301
#[derive(Debug, Serialize, Deserialize, PartialEq)]
13021302
struct Document {

0 commit comments

Comments
 (0)