|
| 1 | +use std::collections::BTreeMap; |
| 2 | + |
| 3 | +use serde::{Deserialize, Serialize}; |
| 4 | +use serde_json::Value; |
| 5 | + |
| 6 | +use crate::{ |
| 7 | + client::Client, |
| 8 | + errors::Error, |
| 9 | + request::{HttpClient, Method}, |
| 10 | +}; |
| 11 | + |
| 12 | +/// Representation of a chat workspace. |
| 13 | +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] |
| 14 | +#[serde(rename_all = "camelCase")] |
| 15 | +pub struct ChatWorkspace { |
| 16 | + pub uid: String, |
| 17 | +} |
| 18 | + |
| 19 | +/// Paginated chat workspace results. |
| 20 | +#[derive(Debug, Clone, Deserialize, Serialize)] |
| 21 | +#[serde(rename_all = "camelCase")] |
| 22 | +pub struct ChatWorkspacesResults { |
| 23 | + pub results: Vec<ChatWorkspace>, |
| 24 | + pub offset: u32, |
| 25 | + pub limit: u32, |
| 26 | + pub total: u32, |
| 27 | +} |
| 28 | + |
| 29 | +/// Chat workspace prompts payload. |
| 30 | +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] |
| 31 | +#[serde(rename_all = "camelCase")] |
| 32 | +pub struct ChatPrompts { |
| 33 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 34 | + pub system: Option<String>, |
| 35 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 36 | + pub search_description: Option<String>, |
| 37 | + #[serde(rename = "searchQParam", skip_serializing_if = "Option::is_none")] |
| 38 | + pub search_q_param: Option<String>, |
| 39 | + #[serde( |
| 40 | + rename = "searchIndexUidParam", |
| 41 | + skip_serializing_if = "Option::is_none" |
| 42 | + )] |
| 43 | + pub search_index_uid_param: Option<String>, |
| 44 | + /// Any additional provider-specific prompt values. |
| 45 | + #[serde(default, flatten, skip_serializing_if = "BTreeMap::is_empty")] |
| 46 | + pub extra: BTreeMap<String, String>, |
| 47 | +} |
| 48 | + |
| 49 | +impl ChatPrompts { |
| 50 | + #[must_use] |
| 51 | + pub fn new() -> Self { |
| 52 | + Self::default() |
| 53 | + } |
| 54 | + |
| 55 | + pub fn set_system(&mut self, value: impl Into<String>) -> &mut Self { |
| 56 | + self.system = Some(value.into()); |
| 57 | + self |
| 58 | + } |
| 59 | + |
| 60 | + pub fn set_search_description(&mut self, value: impl Into<String>) -> &mut Self { |
| 61 | + self.search_description = Some(value.into()); |
| 62 | + self |
| 63 | + } |
| 64 | + |
| 65 | + pub fn set_search_q_param(&mut self, value: impl Into<String>) -> &mut Self { |
| 66 | + self.search_q_param = Some(value.into()); |
| 67 | + self |
| 68 | + } |
| 69 | + |
| 70 | + pub fn set_search_index_uid_param(&mut self, value: impl Into<String>) -> &mut Self { |
| 71 | + self.search_index_uid_param = Some(value.into()); |
| 72 | + self |
| 73 | + } |
| 74 | + |
| 75 | + pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self { |
| 76 | + self.extra.insert(key.into(), value.into()); |
| 77 | + self |
| 78 | + } |
| 79 | +} |
| 80 | + |
| 81 | +/// Chat workspace settings payload. |
| 82 | +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] |
| 83 | +#[serde(rename_all = "camelCase")] |
| 84 | +pub struct ChatWorkspaceSettings { |
| 85 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 86 | + pub source: Option<String>, |
| 87 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 88 | + pub org_id: Option<String>, |
| 89 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 90 | + pub project_id: Option<String>, |
| 91 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 92 | + pub api_version: Option<String>, |
| 93 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 94 | + pub deployment_id: Option<String>, |
| 95 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 96 | + pub base_url: Option<String>, |
| 97 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 98 | + pub api_key: Option<String>, |
| 99 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 100 | + pub prompts: Option<ChatPrompts>, |
| 101 | +} |
| 102 | + |
| 103 | +impl ChatWorkspaceSettings { |
| 104 | + #[must_use] |
| 105 | + pub fn new() -> Self { |
| 106 | + Self::default() |
| 107 | + } |
| 108 | + |
| 109 | + pub fn set_source(&mut self, source: impl Into<String>) -> &mut Self { |
| 110 | + self.source = Some(source.into()); |
| 111 | + self |
| 112 | + } |
| 113 | + |
| 114 | + pub fn set_org_id(&mut self, org_id: impl Into<String>) -> &mut Self { |
| 115 | + self.org_id = Some(org_id.into()); |
| 116 | + self |
| 117 | + } |
| 118 | + |
| 119 | + pub fn set_project_id(&mut self, project_id: impl Into<String>) -> &mut Self { |
| 120 | + self.project_id = Some(project_id.into()); |
| 121 | + self |
| 122 | + } |
| 123 | + |
| 124 | + pub fn set_api_version(&mut self, api_version: impl Into<String>) -> &mut Self { |
| 125 | + self.api_version = Some(api_version.into()); |
| 126 | + self |
| 127 | + } |
| 128 | + |
| 129 | + pub fn set_deployment_id(&mut self, deployment_id: impl Into<String>) -> &mut Self { |
| 130 | + self.deployment_id = Some(deployment_id.into()); |
| 131 | + self |
| 132 | + } |
| 133 | + |
| 134 | + pub fn set_base_url(&mut self, base_url: impl Into<String>) -> &mut Self { |
| 135 | + self.base_url = Some(base_url.into()); |
| 136 | + self |
| 137 | + } |
| 138 | + |
| 139 | + pub fn set_api_key(&mut self, api_key: impl Into<String>) -> &mut Self { |
| 140 | + self.api_key = Some(api_key.into()); |
| 141 | + self |
| 142 | + } |
| 143 | + |
| 144 | + pub fn set_prompts(&mut self, prompts: ChatPrompts) -> &mut Self { |
| 145 | + self.prompts = Some(prompts); |
| 146 | + self |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +/// Query builder for listing chat workspaces. |
| 151 | +#[derive(Debug, Serialize)] |
| 152 | +pub struct ChatWorkspacesQuery<'a, Http: HttpClient> { |
| 153 | + #[serde(skip_serializing)] |
| 154 | + pub client: &'a Client<Http>, |
| 155 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 156 | + pub offset: Option<usize>, |
| 157 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 158 | + pub limit: Option<usize>, |
| 159 | +} |
| 160 | + |
| 161 | +impl<'a, Http: HttpClient> ChatWorkspacesQuery<'a, Http> { |
| 162 | + #[must_use] |
| 163 | + pub fn new(client: &'a Client<Http>) -> Self { |
| 164 | + Self { |
| 165 | + client, |
| 166 | + offset: None, |
| 167 | + limit: None, |
| 168 | + } |
| 169 | + } |
| 170 | + |
| 171 | + pub fn with_offset(&mut self, offset: usize) -> &mut Self { |
| 172 | + self.offset = Some(offset); |
| 173 | + self |
| 174 | + } |
| 175 | + |
| 176 | + pub fn with_limit(&mut self, limit: usize) -> &mut Self { |
| 177 | + self.limit = Some(limit); |
| 178 | + self |
| 179 | + } |
| 180 | + |
| 181 | + pub async fn execute(&self) -> Result<ChatWorkspacesResults, Error> { |
| 182 | + self.client.list_chat_workspaces_with(self).await |
| 183 | + } |
| 184 | +} |
| 185 | + |
| 186 | +impl<Http: HttpClient> Client<Http> { |
| 187 | + /// List all chat workspaces. |
| 188 | + pub async fn list_chat_workspaces(&self) -> Result<ChatWorkspacesResults, Error> { |
| 189 | + self.http_client |
| 190 | + .request::<(), (), ChatWorkspacesResults>( |
| 191 | + &format!("{}/chats", self.host), |
| 192 | + Method::Get { query: () }, |
| 193 | + 200, |
| 194 | + ) |
| 195 | + .await |
| 196 | + } |
| 197 | + |
| 198 | + /// List chat workspaces using query parameters. |
| 199 | + pub async fn list_chat_workspaces_with( |
| 200 | + &self, |
| 201 | + query: &ChatWorkspacesQuery<'_, Http>, |
| 202 | + ) -> Result<ChatWorkspacesResults, Error> { |
| 203 | + self.http_client |
| 204 | + .request::<&ChatWorkspacesQuery<'_, Http>, (), ChatWorkspacesResults>( |
| 205 | + &format!("{}/chats", self.host), |
| 206 | + Method::Get { query }, |
| 207 | + 200, |
| 208 | + ) |
| 209 | + .await |
| 210 | + } |
| 211 | + |
| 212 | + /// Retrieve a chat workspace by uid. |
| 213 | + pub async fn get_chat_workspace(&self, uid: impl AsRef<str>) -> Result<ChatWorkspace, Error> { |
| 214 | + self.http_client |
| 215 | + .request::<(), (), ChatWorkspace>( |
| 216 | + &format!("{}/chats/{}", self.host, uid.as_ref()), |
| 217 | + Method::Get { query: () }, |
| 218 | + 200, |
| 219 | + ) |
| 220 | + .await |
| 221 | + } |
| 222 | + |
| 223 | + /// Retrieve chat workspace settings. |
| 224 | + pub async fn get_chat_workspace_settings( |
| 225 | + &self, |
| 226 | + uid: impl AsRef<str>, |
| 227 | + ) -> Result<ChatWorkspaceSettings, Error> { |
| 228 | + self.http_client |
| 229 | + .request::<(), (), ChatWorkspaceSettings>( |
| 230 | + &format!("{}/chats/{}/settings", self.host, uid.as_ref()), |
| 231 | + Method::Get { query: () }, |
| 232 | + 200, |
| 233 | + ) |
| 234 | + .await |
| 235 | + } |
| 236 | + |
| 237 | + /// Create or update chat workspace settings. |
| 238 | + pub async fn update_chat_workspace_settings( |
| 239 | + &self, |
| 240 | + uid: impl AsRef<str>, |
| 241 | + settings: &ChatWorkspaceSettings, |
| 242 | + ) -> Result<ChatWorkspaceSettings, Error> { |
| 243 | + self.http_client |
| 244 | + .request::<(), &ChatWorkspaceSettings, ChatWorkspaceSettings>( |
| 245 | + &format!("{}/chats/{}/settings", self.host, uid.as_ref()), |
| 246 | + Method::Patch { |
| 247 | + query: (), |
| 248 | + body: settings, |
| 249 | + }, |
| 250 | + 200, |
| 251 | + ) |
| 252 | + .await |
| 253 | + } |
| 254 | + |
| 255 | + /// Reset chat workspace settings to defaults. |
| 256 | + pub async fn reset_chat_workspace_settings( |
| 257 | + &self, |
| 258 | + uid: impl AsRef<str>, |
| 259 | + ) -> Result<ChatWorkspaceSettings, Error> { |
| 260 | + self.http_client |
| 261 | + .request::<(), (), ChatWorkspaceSettings>( |
| 262 | + &format!("{}/chats/{}/settings", self.host, uid.as_ref()), |
| 263 | + Method::Delete { query: () }, |
| 264 | + 200, |
| 265 | + ) |
| 266 | + .await |
| 267 | + } |
| 268 | +} |
| 269 | + |
| 270 | +#[cfg(feature = "reqwest")] |
| 271 | +impl Client<crate::reqwest::ReqwestClient> { |
| 272 | + /// Stream chat completions for a workspace. |
| 273 | + pub async fn stream_chat_completion<S: Serialize + ?Sized>( |
| 274 | + &self, |
| 275 | + uid: impl AsRef<str>, |
| 276 | + body: &S, |
| 277 | + ) -> Result<reqwest::Response, Error> { |
| 278 | + use reqwest::header::{HeaderValue, ACCEPT, CONTENT_TYPE}; |
| 279 | + |
| 280 | + let payload = serde_json::to_vec(body).map_err(Error::ParseError)?; |
| 281 | + |
| 282 | + let response = self |
| 283 | + .http_client |
| 284 | + .inner() |
| 285 | + .post(format!( |
| 286 | + "{}/chats/{}/chat/completions", |
| 287 | + self.host, |
| 288 | + uid.as_ref() |
| 289 | + )) |
| 290 | + .header(ACCEPT, HeaderValue::from_static("text/event-stream")) |
| 291 | + .header(CONTENT_TYPE, HeaderValue::from_static("application/json")) |
| 292 | + .body(payload) |
| 293 | + .send() |
| 294 | + .await?; |
| 295 | + |
| 296 | + let status = response.status(); |
| 297 | + if !status.is_success() { |
| 298 | + let url = response.url().to_string(); |
| 299 | + let mut body = response.text().await?; |
| 300 | + if body.is_empty() { |
| 301 | + body = "null".to_string(); |
| 302 | + } |
| 303 | + let err = |
| 304 | + match crate::request::parse_response::<Value>(status.as_u16(), 200, &body, url) { |
| 305 | + Ok(_) => unreachable!("parse_response succeeded on a non-successful status"), |
| 306 | + Err(err) => err, |
| 307 | + }; |
| 308 | + return Err(err); |
| 309 | + } |
| 310 | + |
| 311 | + Ok(response) |
| 312 | + } |
| 313 | +} |
| 314 | + |
| 315 | +#[cfg(test)] |
| 316 | +mod tests { |
| 317 | + use super::*; |
| 318 | + use crate::features::ExperimentalFeatures; |
| 319 | + use meilisearch_test_macro::meilisearch_test; |
| 320 | + |
| 321 | + #[meilisearch_test] |
| 322 | + async fn chat_workspace_lifecycle(client: Client, name: String) -> Result<(), Error> { |
| 323 | + let mut features = ExperimentalFeatures::new(&client); |
| 324 | + features.set_chat_completions(true); |
| 325 | + let _ = features.update().await?; |
| 326 | + |
| 327 | + let workspace = format!("{name}-workspace"); |
| 328 | + |
| 329 | + let mut prompts = ChatPrompts::new(); |
| 330 | + prompts.set_system("You are a helpful assistant."); |
| 331 | + prompts.set_search_description("Use search to fetch relevant documents."); |
| 332 | + |
| 333 | + let mut settings = ChatWorkspaceSettings::new(); |
| 334 | + settings |
| 335 | + .set_source("openAi") |
| 336 | + .set_api_key("sk-test") |
| 337 | + .set_prompts(prompts.clone()); |
| 338 | + |
| 339 | + let updated = client |
| 340 | + .update_chat_workspace_settings(&workspace, &settings) |
| 341 | + .await?; |
| 342 | + assert_eq!(updated.source.as_deref(), Some("openAi")); |
| 343 | + let updated_prompts = updated |
| 344 | + .prompts |
| 345 | + .expect("updated settings should contain prompts"); |
| 346 | + assert_eq!(updated_prompts.system.as_deref(), prompts.system.as_deref()); |
| 347 | + assert_eq!( |
| 348 | + updated_prompts.search_description.as_deref(), |
| 349 | + prompts.search_description.as_deref() |
| 350 | + ); |
| 351 | + if let Some(masked_key) = updated.api_key.as_ref() { |
| 352 | + assert_ne!( |
| 353 | + masked_key, "sk-test", |
| 354 | + "API key should not be returned in clear text" |
| 355 | + ); |
| 356 | + } |
| 357 | + |
| 358 | + let workspace_info = client.get_chat_workspace(&workspace).await?; |
| 359 | + assert_eq!(workspace_info.uid, workspace); |
| 360 | + |
| 361 | + let fetched_settings = client.get_chat_workspace_settings(&workspace).await?; |
| 362 | + assert_eq!(fetched_settings.source.as_deref(), Some("openAi")); |
| 363 | + let fetched_prompts = fetched_settings |
| 364 | + .prompts |
| 365 | + .expect("workspace should have prompts configured"); |
| 366 | + assert_eq!(fetched_prompts.system.as_deref(), prompts.system.as_deref()); |
| 367 | + assert_eq!( |
| 368 | + fetched_prompts.search_description.as_deref(), |
| 369 | + prompts.search_description.as_deref() |
| 370 | + ); |
| 371 | + |
| 372 | + let list = client.list_chat_workspaces().await?; |
| 373 | + assert!(list.results.iter().any(|w| w.uid == workspace)); |
| 374 | + |
| 375 | + let mut query = ChatWorkspacesQuery::new(&client); |
| 376 | + query.with_limit(1); |
| 377 | + let limited = query.execute().await?; |
| 378 | + assert_eq!(limited.limit, 1); |
| 379 | + |
| 380 | + let _ = client.reset_chat_workspace_settings(&workspace).await?; |
| 381 | + |
| 382 | + Ok(()) |
| 383 | + } |
| 384 | +} |
0 commit comments