From 7796784ff2d25b0ba8843a72fab28fdce59f5e26 Mon Sep 17 00:00:00 2001
From: Maverick Liu
Date: Thu, 24 Apr 2025 17:58:31 +0800
Subject: [PATCH] feat: tool management

---
 .gitignore                            |   1 +
 async-openai/Cargo.toml               |   1 +
 async-openai/src/lib.rs               |   2 +
 async-openai/src/tools.rs             | 244 ++++++++++++++++++++
 examples/tool-call-stream/Cargo.toml  |   4 +-
 examples/tool-call-stream/src/main.rs | 308 ++++++++++----------------
 examples/tool-call/Cargo.toml         |   4 +-
 examples/tool-call/src/main.rs        | 163 ++++++--------
 8 files changed, 437 insertions(+), 290 deletions(-)
 create mode 100644 async-openai/src/tools.rs

diff --git a/.gitignore b/.gitignore
index 0d4aa54e..60a77793 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 target
 Cargo.lock
 **/*.rs.bk
 .DS_Store
+.env
 
 # directory used to store images data
diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml
index 578c25e8..07a27c4c 100644
--- a/async-openai/Cargo.toml
+++ b/async-openai/Cargo.toml
@@ -50,6 +50,7 @@ secrecy = { version = "0.10.3", features = ["serde"] }
 bytes = "1.9.0"
 eventsource-stream = "0.2.3"
 tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false }
+schemars = "0.8.22"
 
 [dev-dependencies]
 tokio-test = "0.4.4"
diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs
index 182e58ae..4aff409f 100644
--- a/async-openai/src/lib.rs
+++ b/async-openai/src/lib.rs
@@ -149,6 +149,7 @@ mod projects;
 mod runs;
 mod steps;
 mod threads;
+mod tools;
 pub mod traits;
 pub mod types;
 mod uploads;
@@ -180,6 +181,7 @@ pub use projects::Projects;
 pub use runs::Runs;
 pub use steps::Steps;
 pub use threads::Threads;
+pub use tools::{Tool, ToolManager, ToolCallStreamManager};
 pub use uploads::Uploads;
 pub use users::Users;
 pub use vector_store_file_batches::VectorStoreFileBatches;
diff --git a/async-openai/src/tools.rs b/async-openai/src/tools.rs
new file mode 100644
index 00000000..6ea1d78f
--- /dev/null
+++ b/async-openai/src/tools.rs
@@ -0,0 +1,244 @@
+//! This module provides functionality for managing and executing tools in an async OpenAI context.
+//! It defines traits and structures for tool management, execution, and streaming.
+use std::{
+    collections::{hash_map::Entry, BTreeMap, HashMap},
+    future::Future,
+    pin::Pin,
+    sync::Arc,
+};
+
+use schemars::{schema_for, JsonSchema};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+use crate::types::{
+    ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
+    ChatCompletionRequestToolMessage, ChatCompletionTool, ChatCompletionToolType, FunctionCall,
+    FunctionCallStream, FunctionObject,
+};
+
+/// A trait defining the interface for tools that can be used with the OpenAI API.
+/// Tools must implement this trait to be used with the ToolManager.
+pub trait Tool: Send + Sync {
+    /// The type of arguments that the tool accepts.
+    type Args: JsonSchema + for<'a> Deserialize<'a> + Send + Sync;
+    /// The type of output that the tool produces.
+    type Output: Serialize + Send + Sync;
+    /// The type of error that the tool can return.
+    type Error: ToString + Send + Sync;
+
+    /// Returns the name of the tool.
+    fn name() -> String {
+        Self::Args::schema_name()
+    }
+
+    /// Returns an optional description of the tool.
+    fn description() -> Option<String> {
+        None
+    }
+
+    /// Returns an optional boolean indicating whether the tool should be strict about the arguments.
+    fn strict() -> Option<bool> {
+        None
+    }
+
+    /// Creates a ChatCompletionTool definition for the tool.
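+    /// As an illustrative sketch (an assumed schema shape, not output copied from
+    /// this patch): for an `Args` struct with a single `location: String` field,
+    /// `parameters` holds the JSON Schema produced by `schema_for!`, roughly an
+    /// object schema with a required `"location"` string property; the exact
+    /// shape depends on the `schemars` version.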
+    fn definition() -> ChatCompletionTool {
+        ChatCompletionTool {
+            r#type: ChatCompletionToolType::Function,
+            function: FunctionObject {
+                name: Self::name(),
+                description: Self::description(),
+                parameters: Some(json!(schema_for!(Self::Args))),
+                strict: Self::strict(),
+            },
+        }
+    }
+
+    /// Executes the tool with the given arguments.
+    /// Returns a Future that resolves to either the tool's output or an error.
+    fn call(
+        &self,
+        args: Self::Args,
+    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send;
+}
+
+/// A dynamic trait for tools that allows for runtime tool management.
+/// This trait provides a way to work with tools without knowing their concrete types at compile time.
+trait ToolDyn: Send + Sync {
+    /// Returns the tool's definition as a ChatCompletionTool.
+    fn definition(&self) -> ChatCompletionTool;
+
+    /// Executes the tool with the given JSON string arguments.
+    /// Returns a Future that resolves to either a JSON string output or an error string.
+    fn call(
+        &self,
+        args: String,
+    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
+}
+
+// Implementation of ToolDyn for any type that implements Tool
+impl<T: Tool> ToolDyn for T {
+    fn definition(&self) -> ChatCompletionTool {
+        T::definition()
+    }
+
+    fn call(
+        &self,
+        args: String,
+    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
+        let future = async move {
+            // Special handling for T::Args = () case
+            // If the tool doesn't require arguments (T::Args is unit type),
+            // we can safely ignore the provided arguments string
+            match serde_json::from_str::<T::Args>(&args)
+                .or_else(|e| serde_json::from_str::<T::Args>("null").map_err(|_| e))
+            {
+                Ok(args) => T::call(self, args)
+                    .await
+                    .map_err(|e| e.to_string())
+                    .and_then(|output| {
+                        serde_json::to_string(&output)
+                            .map_err(|e| format!("Failed to serialize output: {}", e))
+                    }),
+                Err(e) => Err(format!("Failed to parse arguments: {}", e)),
+            }
+        };
+        Box::pin(future)
+    }
+}
+
+/// A manager for tools that allows adding, retrieving, and executing tools.
+#[derive(Default, Clone)]
+pub struct ToolManager {
+    /// A map of tool names to their dynamic implementations.
+    tools: BTreeMap<String, Arc<dyn ToolDyn>>,
+}
+
+impl ToolManager {
+    /// Creates a new ToolManager.
+    pub fn new() -> Self {
+        Self {
+            tools: BTreeMap::new(),
+        }
+    }
+
+    /// Adds a new tool to the manager.
+    pub fn add_tool<T: Tool + 'static>(&mut self, tool: T) {
+        self.tools.insert(T::name(), Arc::new(tool));
+    }
+
+    /// Removes a tool from the manager.
+    pub fn remove_tool(&mut self, name: &str) -> bool {
+        self.tools.remove(name).is_some()
+    }
+
+    /// Returns the definitions of all tools in the manager.
+    pub fn get_tools(&self) -> Vec<ChatCompletionTool> {
+        self.tools.values().map(|tool| tool.definition()).collect()
+    }
+
+    /// Executes multiple tool calls concurrently and returns their results.
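+    /// A call whose name is not registered here does not abort the batch: as
+    /// implemented below, it yields a tool message whose content is an error
+    /// string, and per-call failures are likewise reported back to the model as
+    /// `"Tool call failed: ..."` messages.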
+    pub async fn call(
+        &self,
+        calls: impl IntoIterator<Item = ChatCompletionMessageToolCall>,
+    ) -> Vec<ChatCompletionRequestToolMessage> {
+        let mut handles = Vec::new();
+        let mut outputs = Vec::new();
+
+        // Spawn a task for each tool call
+        for call in calls {
+            if let Some(tool) = self.tools.get(&call.function.name).cloned() {
+                let handle = tokio::spawn(async move { tool.call(call.function.arguments).await });
+                handles.push((call.id, handle));
+            } else {
+                outputs.push(ChatCompletionRequestToolMessage {
+                    content: "Tool call failed: tool not found".into(),
+                    tool_call_id: call.id,
+                });
+            }
+        }
+
+        // Collect results from all spawned tasks
+        for (id, handle) in handles {
+            let output = match handle.await {
+                Ok(Ok(output)) => output,
+                Ok(Err(e)) => {
+                    format!("Tool call failed: {}", e)
+                }
+                Err(_) => {
+                    format!("Tool call failed: runtime error")
+                }
+            };
+            outputs.push(ChatCompletionRequestToolMessage {
+                content: output.into(),
+                tool_call_id: id,
+            });
+        }
+        outputs
+    }
+}
+
+/// A manager for handling streaming tool calls.
+/// This structure helps manage and merge tool call chunks that arrive in a streaming fashion.
+#[derive(Default, Clone, Debug)]
+pub struct ToolCallStreamManager(HashMap<u32, ChatCompletionMessageToolCall>);
+
+impl ToolCallStreamManager {
+    /// Creates a new empty ToolCallStreamManager.
+    pub fn new() -> Self {
+        Self(HashMap::new())
+    }
+
+    /// Processes a single streaming tool call chunk and merges it with existing data.
+    pub fn process_chunk(&mut self, chunk: ChatCompletionMessageToolCallChunk) {
+        match self.0.entry(chunk.index) {
+            Entry::Occupied(mut o) => {
+                if let Some(FunctionCallStream {
+                    name: _,
+                    arguments: Some(arguments),
+                }) = chunk.function
+                {
+                    o.get_mut().function.arguments.push_str(&arguments);
+                }
+            }
+            Entry::Vacant(o) => {
+                let ChatCompletionMessageToolCallChunk {
+                    index: _,
+                    id: Some(id),
+                    r#type: _,
+                    function:
+                        Some(FunctionCallStream {
+                            name: Some(name),
+                            arguments: Some(arguments),
+                        }),
+                } = chunk
+                else {
+                    tracing::error!("Tool call chunk is not complete: {:?}", chunk);
+                    return;
+                };
+                let tool_call = ChatCompletionMessageToolCall {
+                    id,
+                    r#type: ChatCompletionToolType::Function,
+                    function: FunctionCall { name, arguments },
+                };
+                o.insert(tool_call);
+            }
+        }
+    }
+
+    /// Processes multiple streaming tool call chunks and merges them with existing data.
+    pub fn process_chunks(
+        &mut self,
+        chunks: impl IntoIterator<Item = ChatCompletionMessageToolCallChunk>,
+    ) {
+        for chunk in chunks {
+            self.process_chunk(chunk);
+        }
+    }
+
+    /// Returns all completed tool calls as a vector.
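+    /// Chunks are merged per tool-call `index`, so interleaved deltas from
+    /// several concurrent tool calls accumulate independently. The order of the
+    /// returned vector follows `HashMap` iteration order and is not guaranteed.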
+    pub fn finish_stream(self) -> Vec<ChatCompletionMessageToolCall> {
+        self.0.into_values().collect()
+    }
+}
diff --git a/examples/tool-call-stream/Cargo.toml b/examples/tool-call-stream/Cargo.toml
index 6d68ba36..0fc20eab 100644
--- a/examples/tool-call-stream/Cargo.toml
+++ b/examples/tool-call-stream/Cargo.toml
@@ -7,8 +7,10 @@ publish = false
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-async-openai = {path = "../../async-openai"}
+async-openai = { path = "../../async-openai" }
 rand = "0.8.5"
+serde = "1.0"
 serde_json = "1.0.135"
 tokio = { version = "1.43.0", features = ["full"] }
 futures = "0.3.31"
+schemars = "0.8.22"
diff --git a/examples/tool-call-stream/src/main.rs b/examples/tool-call-stream/src/main.rs
index 230ee9a3..fcf9d17c 100644
--- a/examples/tool-call-stream/src/main.rs
+++ b/examples/tool-call-stream/src/main.rs
@@ -1,200 +1,52 @@
-use std::collections::HashMap;
 use std::error::Error;
 use std::io::{stdout, Write};
-use std::sync::Arc;
 
 use async_openai::types::{
-    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
-    ChatCompletionRequestMessage, ChatCompletionRequestToolMessageArgs,
-    ChatCompletionRequestUserMessageArgs, ChatCompletionToolArgs, ChatCompletionToolType,
-    FinishReason, FunctionCall, FunctionObjectArgs,
+    ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
+    ChatCompletionRequestUserMessageArgs, FinishReason,
 };
 use async_openai::{types::CreateChatCompletionRequestArgs, Client};
+use async_openai::{Tool, ToolCallStreamManager, ToolManager};
 use futures::StreamExt;
 use rand::seq::SliceRandom;
 use rand::{thread_rng, Rng};
-use serde_json::{json, Value};
-use tokio::sync::Mutex;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn Error>> {
     let client = Client::new();
 
-    let user_prompt = "What's the weather like in Boston and Atlanta?";
+    let mut messages = vec![ChatCompletionRequestUserMessageArgs::default()
+        .content("What's the weather like in Boston and Atlanta?")
+        .build()?
+        .into()];
+    let weather_tool = WeatherTool;
+    let mut tool_manager = ToolManager::new();
+    tool_manager.add_tool(weather_tool);
+    let tools = tool_manager.get_tools();
 
     let request = CreateChatCompletionRequestArgs::default()
         .max_tokens(512u32)
         .model("gpt-4-1106-preview")
-        .messages([ChatCompletionRequestUserMessageArgs::default()
-            .content(user_prompt)
-            .build()?
-            .into()])
-        .tools(vec![ChatCompletionToolArgs::default()
-            .r#type(ChatCompletionToolType::Function)
-            .function(
-                FunctionObjectArgs::default()
-                    .name("get_current_weather")
-                    .description("Get the current weather in a given location")
-                    .parameters(json!({
-                        "type": "object",
-                        "properties": {
-                            "location": {
-                                "type": "string",
-                                "description": "The city and state, e.g. San Francisco, CA",
-                            },
-                            "unit": { "type": "string", "enum": ["celsius", "fahrenheit"] },
-                        },
-                        "required": ["location"],
-                    }))
-                    .build()?,
-            )
-            .build()?])
+        .messages(messages.clone())
+        .tools(tools)
         .build()?;
 
     let mut stream = client.chat().create_stream(request).await?;
 
-    let tool_call_states: Arc<Mutex<HashMap<(u32, u32), ChatCompletionMessageToolCall>>> =
-        Arc::new(Mutex::new(HashMap::new()));
+    let mut tool_call_stream_manager = ToolCallStreamManager::new();
+    let mut is_end_with_tool_call = false;
 
     while let Some(result) = stream.next().await {
         match result {
             Ok(response) => {
                 for chat_choice in response.choices {
-                    let function_responses: Arc<
-                        Mutex<Vec<(ChatCompletionMessageToolCall, Value)>>,
-                    > = Arc::new(Mutex::new(Vec::new()));
-                    if let Some(tool_calls) = chat_choice.delta.tool_calls {
-                        for tool_call_chunk in tool_calls.into_iter() {
-                            let key = (chat_choice.index, tool_call_chunk.index);
-                            let states = tool_call_states.clone();
-                            let tool_call_data = tool_call_chunk.clone();
-
-                            let mut states_lock = states.lock().await;
-                            let state = states_lock.entry(key).or_insert_with(|| {
-                                ChatCompletionMessageToolCall {
-                                    id: tool_call_data.id.clone().unwrap_or_default(),
-                                    r#type: ChatCompletionToolType::Function,
-                                    function: FunctionCall {
-                                        name: tool_call_data
-                                            .function
-                                            .as_ref()
-                                            .and_then(|f| f.name.clone())
-                                            .unwrap_or_default(),
-                                        arguments: tool_call_data
-                                            .function
-                                            .as_ref()
-                                            .and_then(|f| f.arguments.clone())
-                                            .unwrap_or_default(),
-                                    },
-                                }
-                            });
-                            if let Some(arguments) = tool_call_chunk
-                                .function
-                                .as_ref()
-                                .and_then(|f| f.arguments.as_ref())
-                            {
-                                state.function.arguments.push_str(arguments);
-                            }
-                        }
+                    if let Some(tool_call_chunks) = chat_choice.delta.tool_calls {
+                        tool_call_stream_manager.process_chunks(tool_call_chunks);
                     }
 
                     if let Some(finish_reason) = &chat_choice.finish_reason {
                         if matches!(finish_reason, FinishReason::ToolCalls) {
-                            let tool_call_states_clone = tool_call_states.clone();
-
-                            let tool_calls_to_process = {
-                                let states_lock = tool_call_states_clone.lock().await;
-                                states_lock
-                                    .iter()
-                                    .map(|(_key, tool_call)| {
-                                        let name = tool_call.function.name.clone();
-                                        let args = tool_call.function.arguments.clone();
-                                        let tool_call_clone = tool_call.clone();
-                                        (name, args, tool_call_clone)
-                                    })
-                                    .collect::<Vec<_>>()
-                            };
-
-                            let mut handles = Vec::new();
-                            for (name, args, tool_call_clone) in tool_calls_to_process {
-                                let response_content_clone = function_responses.clone();
-                                let handle = tokio::spawn(async move {
-                                    let response_content = call_fn(&name, &args).await.unwrap();
-                                    let mut function_responses_lock =
-                                        response_content_clone.lock().await;
-                                    function_responses_lock
-                                        .push((tool_call_clone, response_content));
-                                });
-                                handles.push(handle);
-                            }
-
-                            for handle in handles {
-                                handle.await.unwrap();
-                            }
-
-                            let function_responses_clone = function_responses.clone();
-                            let function_responses_lock = function_responses_clone.lock().await;
-                            let mut messages: Vec<ChatCompletionRequestMessage> =
-                                vec![ChatCompletionRequestUserMessageArgs::default()
-                                    .content(user_prompt)
-                                    .build()?
-                                    .into()];
-
-                            let tool_calls: Vec<ChatCompletionMessageToolCall> =
-                                function_responses_lock
-                                    .iter()
-                                    .map(|tc| tc.0.clone())
-                                    .collect();
-
-                            let assistant_messages: ChatCompletionRequestMessage =
-                                ChatCompletionRequestAssistantMessageArgs::default()
-                                    .tool_calls(tool_calls)
-                                    .build()
-                                    .map_err(|e| Box::new(e) as Box<dyn Error>)
-                                    .unwrap()
-                                    .into();
-
-                            let tool_messages: Vec<ChatCompletionRequestMessage> =
-                                function_responses_lock
-                                    .iter()
-                                    .map(|tc| {
-                                        ChatCompletionRequestToolMessageArgs::default()
-                                            .content(tc.1.to_string())
-                                            .tool_call_id(tc.0.id.clone())
-                                            .build()
-                                            .map_err(|e| Box::new(e) as Box<dyn Error>)
-                                            .unwrap()
-                                            .into()
-                                    })
-                                    .collect();
-
-                            messages.push(assistant_messages);
-                            messages.extend(tool_messages);
-
-                            let request = CreateChatCompletionRequestArgs::default()
-                                .max_tokens(512u32)
-                                .model("gpt-4-1106-preview")
-                                .messages(messages)
-                                .build()
-                                .map_err(|e| Box::new(e) as Box<dyn Error>)?;
-
-                            let mut stream = client.chat().create_stream(request).await?;
-
-                            let mut response_content = String::new();
-                            let mut lock = stdout().lock();
-                            while let Some(result) = stream.next().await {
-                                match result {
-                                    Ok(response) => {
-                                        for chat_choice in response.choices.iter() {
-                                            if let Some(ref content) = chat_choice.delta.content {
-                                                write!(lock, "{}", content).unwrap();
-                                                response_content.push_str(content);
-                                            }
-                                        }
-                                    }
-                                    Err(err) => {
-                                        return Err(Box::new(err) as Box<dyn Error>);
-                                    }
-                                }
-                            }
+                            is_end_with_tool_call = true;
                         }
                     }
@@ -214,40 +66,112 @@ async fn main() -> Result<(), Box<dyn Error>> {
             .map_err(|e| Box::new(e) as Box<dyn Error>)?;
     }
 
+    if !is_end_with_tool_call {
+        return Err("The response is not ended with tool call".into());
+    }
+    let tool_calls = tool_call_stream_manager.finish_stream();
+    let function_responses = tool_manager.call(tool_calls.clone()).await;
+
+    let assistant_messages: ChatCompletionRequestMessage =
+        ChatCompletionRequestAssistantMessageArgs::default()
+            .tool_calls(tool_calls)
+            .build()
+            .map_err(|e| Box::new(e) as Box<dyn Error>)
+            .unwrap()
+            .into();
+
+    let tool_messages: Vec<ChatCompletionRequestMessage> = function_responses
+        .into_iter()
+        .map(|res| res.into())
+        .collect();
+
+    messages.push(assistant_messages);
+    messages.extend(tool_messages);
+
+    let request = CreateChatCompletionRequestArgs::default()
+        .max_tokens(512u32)
+        .model("gpt-4-1106-preview")
+        .messages(messages)
+        .build()
+        .map_err(|e| Box::new(e) as Box<dyn Error>)?;
+
+    let mut stream = client.chat().create_stream(request).await?;
+
+    let mut response_content = String::new();
+    let mut lock = stdout().lock();
+    while let Some(result) = stream.next().await {
+        match result {
+            Ok(response) => {
+                for chat_choice in response.choices.iter() {
+                    if let Some(ref content) = chat_choice.delta.content {
+                        write!(lock, "{}", content).unwrap();
+                        response_content.push_str(content);
+                    }
+                }
+            }
+            Err(err) => {
+                return Err(Box::new(err) as Box<dyn Error>);
+            }
+        }
+    }
+
     Ok(())
 }
 
-async fn call_fn(name: &str, args: &str) -> Result<Value, Box<dyn Error>> {
-    let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
-        HashMap::new();
-    available_functions.insert("get_current_weather", get_current_weather);
+#[derive(Debug, JsonSchema, Deserialize, Serialize)]
+enum Unit {
+    Fahrenheit,
+    Celsius,
+}
 
-    let function_args: serde_json::Value = args.parse().unwrap();
+#[derive(Debug, JsonSchema, Deserialize)]
+struct WeatherRequest {
+    /// The city and state, e.g. San Francisco, CA
+    location: String,
+    unit: Unit,
+}
 
-    let location = function_args["location"].as_str().unwrap();
-    let unit = function_args["unit"].as_str().unwrap_or("fahrenheit");
-    let function = available_functions.get(name).unwrap();
-    let function_response = function(location, unit);
-    Ok(function_response)
+#[derive(Debug, Serialize)]
+struct WeatherResponse {
+    location: String,
+    temperature: String,
+    unit: Unit,
+    forecast: String,
 }
 
-fn get_current_weather(location: &str, unit: &str) -> serde_json::Value {
-    let mut rng = thread_rng();
+struct WeatherTool;
+
+impl Tool for WeatherTool {
+    type Args = WeatherRequest;
+    type Output = WeatherResponse;
+    type Error = String;
+
+    fn name() -> String {
+        "get_current_weather".to_string()
+    }
+
+    fn description() -> Option<String> {
+        Some("Get the current weather in a given location".to_string())
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let mut rng = thread_rng();
 
-    let temperature: i32 = rng.gen_range(20..=55);
+        let temperature: i32 = rng.gen_range(20..=55);
 
-    let forecasts = [
-        "sunny", "cloudy", "overcast", "rainy", "windy", "foggy", "snowy",
-    ];
+        let forecasts = [
+            "sunny", "cloudy", "overcast", "rainy", "windy", "foggy", "snowy",
+        ];
 
-    let forecast = forecasts.choose(&mut rng).unwrap_or(&"sunny");
+        let forecast = forecasts.choose(&mut rng).unwrap_or(&"sunny");
 
-    let weather_info = json!({
-        "location": location,
-        "temperature": temperature.to_string(),
-        "unit": unit,
-        "forecast": forecast
-    });
+        let weather_info = WeatherResponse {
+            location: args.location,
+            temperature: temperature.to_string(),
+            unit: args.unit,
+            forecast: forecast.to_string(),
+        };
 
-    weather_info
+        Ok(weather_info)
+    }
 }
diff --git a/examples/tool-call/Cargo.toml b/examples/tool-call/Cargo.toml
index e6a2dc63..9cdbd89c 100644
--- a/examples/tool-call/Cargo.toml
+++ b/examples/tool-call/Cargo.toml
@@ -7,8 +7,10 @@ publish = false
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-async-openai = {path = "../../async-openai"}
+async-openai = { path = "../../async-openai" }
 rand = "0.8.5"
+serde = "1.0"
 serde_json = "1.0.135"
 tokio = { version = "1.43.0", features = ["full"] }
 futures = "0.3.31"
+schemars = "0.8.22"
diff --git a/examples/tool-call/src/main.rs b/examples/tool-call/src/main.rs
index c88fa2fa..ae32f1de 100644
--- a/examples/tool-call/src/main.rs
+++ b/examples/tool-call/src/main.rs
@@ -1,50 +1,34 @@
-use std::collections::HashMap;
 use std::io::{stdout, Write};
 
 use async_openai::types::{
-    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
-    ChatCompletionRequestMessage, ChatCompletionRequestToolMessageArgs,
-    ChatCompletionRequestUserMessageArgs, ChatCompletionToolArgs, ChatCompletionToolType,
-    FunctionObjectArgs,
+    ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
+    ChatCompletionRequestUserMessageArgs,
 };
 use async_openai::{types::CreateChatCompletionRequestArgs, Client};
+use async_openai::{Tool, ToolManager};
 use futures::StreamExt;
 use rand::seq::SliceRandom;
 use rand::{thread_rng, Rng};
-use serde_json::{json, Value};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let client = Client::new();
 
-    let user_prompt = "What's the weather like in Boston and Atlanta?";
-
+    let mut messages = vec![ChatCompletionRequestUserMessageArgs::default()
+        .content("What's the weather like in Boston and Atlanta?")
+        .build()?
+        .into()];
+
+    let weather_tool = WeatherTool;
+    let mut tool_manager = ToolManager::new();
+    tool_manager.add_tool(weather_tool);
+    let tools = tool_manager.get_tools();
     let request = CreateChatCompletionRequestArgs::default()
         .max_tokens(512u32)
         .model("gpt-4-1106-preview")
-        .messages([ChatCompletionRequestUserMessageArgs::default()
-            .content(user_prompt)
-            .build()?
-            .into()])
-        .tools(vec![ChatCompletionToolArgs::default()
-            .r#type(ChatCompletionToolType::Function)
-            .function(
-                FunctionObjectArgs::default()
-                    .name("get_current_weather")
-                    .description("Get the current weather in a given location")
-                    .parameters(json!({
-                        "type": "object",
-                        "properties": {
-                            "location": {
-                                "type": "string",
-                                "description": "The city and state, e.g. San Francisco, CA",
-                            },
-                            "unit": { "type": "string", "enum": ["celsius", "fahrenheit"] },
-                        },
-                        "required": ["location"],
-                    }))
-                    .build()?,
-            )
-            .build()?])
+        .messages(messages.clone())
+        .tools(tools)
         .build()?;
 
     let response_message = client
@@ -58,52 +42,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         .clone();
 
     if let Some(tool_calls) = response_message.tool_calls {
-        let mut handles = Vec::new();
-        for tool_call in tool_calls {
-            let name = tool_call.function.name.clone();
-            let args = tool_call.function.arguments.clone();
-            let tool_call_clone = tool_call.clone();
-
-            let handle =
-                tokio::spawn(async move { call_fn(&name, &args).await.unwrap_or_default() });
-            handles.push((handle, tool_call_clone));
-        }
-
-        let mut function_responses = Vec::new();
-
-        for (handle, tool_call_clone) in handles {
-            if let Ok(response_content) = handle.await {
-                function_responses.push((tool_call_clone, response_content));
-            }
-        }
-
-        let mut messages: Vec<ChatCompletionRequestMessage> =
-            vec![ChatCompletionRequestUserMessageArgs::default()
-                .content(user_prompt)
-                .build()?
-                .into()];
-
-        let tool_calls: Vec<ChatCompletionMessageToolCall> = function_responses
-            .iter()
-            .map(|(tool_call, _response_content)| tool_call.clone())
-            .collect();
-
         let assistant_messages: ChatCompletionRequestMessage =
             ChatCompletionRequestAssistantMessageArgs::default()
-                .tool_calls(tool_calls)
+                .tool_calls(tool_calls.clone())
                 .build()?
                 .into();
 
+        let function_responses = tool_manager.call(tool_calls.clone()).await;
         let tool_messages: Vec<ChatCompletionRequestMessage> = function_responses
-            .iter()
-            .map(|(tool_call, response_content)| {
-                ChatCompletionRequestToolMessageArgs::default()
-                    .content(response_content.to_string())
-                    .tool_call_id(tool_call.id.clone())
-                    .build()
-                    .unwrap()
-                    .into()
-            })
+            .into_iter()
+            .map(|res| res.into())
             .collect();
 
         messages.push(assistant_messages);
@@ -140,37 +88,60 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     Ok(())
 }
 
-async fn call_fn(name: &str, args: &str) -> Result<Value, Box<dyn std::error::Error>> {
-    let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
-        HashMap::new();
-    available_functions.insert("get_current_weather", get_current_weather);
+#[derive(Debug, JsonSchema, Deserialize, Serialize)]
+enum Unit {
+    Fahrenheit,
+    Celsius,
+}
 
-    let function_args: serde_json::Value = args.parse().unwrap();
+#[derive(Debug, JsonSchema, Deserialize)]
+struct WeatherRequest {
+    /// The city and state, e.g. San Francisco, CA
+    location: String,
+    unit: Unit,
+}
 
-    let location = function_args["location"].as_str().unwrap();
-    let unit = function_args["unit"].as_str().unwrap_or("fahrenheit");
-    let function = available_functions.get(name).unwrap();
-    let function_response = function(location, unit);
-    Ok(function_response)
+#[derive(Debug, Serialize)]
+struct WeatherResponse {
+    location: String,
+    temperature: String,
+    unit: Unit,
+    forecast: String,
 }
 
-fn get_current_weather(location: &str, unit: &str) -> serde_json::Value {
-    let mut rng = thread_rng();
+struct WeatherTool;
+
+impl Tool for WeatherTool {
+    type Args = WeatherRequest;
+    type Output = WeatherResponse;
+    type Error = String;
+
+    fn name() -> String {
+        "get_current_weather".to_string()
+    }
+
+    fn description() -> Option<String> {
+        Some("Get the current weather in a given location".to_string())
+    }
 
-    let temperature: i32 = rng.gen_range(20..=55);
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let mut rng = thread_rng();
 
-    let forecasts = [
-        "sunny", "cloudy", "overcast", "rainy", "windy", "foggy", "snowy",
-    ];
+        let temperature: i32 = rng.gen_range(20..=55);
 
-    let forecast = forecasts.choose(&mut rng).unwrap_or(&"sunny");
+        let forecasts = [
+            "sunny", "cloudy", "overcast", "rainy", "windy", "foggy", "snowy",
+        ];
 
-    let weather_info = json!({
-        "location": location,
-        "temperature": temperature.to_string(),
-        "unit": unit,
-        "forecast": forecast
-    });
+        let forecast = forecasts.choose(&mut rng).unwrap_or(&"sunny");
 
-    weather_info
+        let weather_info = WeatherResponse {
+            location: args.location,
+            temperature: temperature.to_string(),
+            unit: args.unit,
+            forecast: forecast.to_string(),
+        };
+
+        Ok(weather_info)
+    }
 }