Skip to content

Commit 7796784

Browse files
committed
feat: tool management
1 parent aeb6d1f commit 7796784

File tree

8 files changed

+437
-290
lines changed

8 files changed

+437
-290
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ target
22
Cargo.lock
33
**/*.rs.bk
44
.DS_Store
5+
.env
56

67
# directory used to store images
78
data

async-openai/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ secrecy = { version = "0.10.3", features = ["serde"] }
5050
bytes = "1.9.0"
5151
eventsource-stream = "0.2.3"
5252
tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false }
53+
schemars = "0.8.22"
5354

5455
[dev-dependencies]
5556
tokio-test = "0.4.4"

async-openai/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ mod projects;
149149
mod runs;
150150
mod steps;
151151
mod threads;
152+
mod tools;
152153
pub mod traits;
153154
pub mod types;
154155
mod uploads;
@@ -180,6 +181,7 @@ pub use projects::Projects;
180181
pub use runs::Runs;
181182
pub use steps::Steps;
182183
pub use threads::Threads;
184+
pub use tools::{Tool, ToolManager, ToolCallStreamManager};
183185
pub use uploads::Uploads;
184186
pub use users::Users;
185187
pub use vector_store_file_batches::VectorStoreFileBatches;

async-openai/src/tools.rs

+244
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
//! This module provides functionality for managing and executing tools in an async OpenAI context.
2+
//! It defines traits and structures for tool management, execution, and streaming.
3+
use std::{
4+
collections::{hash_map::Entry, BTreeMap, HashMap},
5+
future::Future,
6+
pin::Pin,
7+
sync::Arc,
8+
};
9+
10+
use schemars::{schema_for, JsonSchema};
11+
use serde::{Deserialize, Serialize};
12+
use serde_json::json;
13+
14+
use crate::types::{
15+
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
16+
ChatCompletionRequestToolMessage, ChatCompletionTool, ChatCompletionToolType, FunctionCall,
17+
FunctionCallStream, FunctionObject,
18+
};
19+
20+
/// A trait defining the interface for tools that can be used with the OpenAI API.
21+
/// Tools must implement this trait to be used with the ToolManager.
22+
pub trait Tool: Send + Sync {
23+
/// The type of arguments that the tool accepts.
24+
type Args: JsonSchema + for<'a> Deserialize<'a> + Send + Sync;
25+
/// The type of output that the tool produces.
26+
type Output: Serialize + Send + Sync;
27+
/// The type of error that the tool can return.
28+
type Error: ToString + Send + Sync;
29+
30+
/// Returns the name of the tool.
31+
fn name() -> String {
32+
Self::Args::schema_name()
33+
}
34+
35+
/// Returns an optional description of the tool.
36+
fn description() -> Option<String> {
37+
None
38+
}
39+
40+
/// Returns an optional boolean indicating whether the tool should be strict about the arguments.
41+
fn strict() -> Option<bool> {
42+
None
43+
}
44+
45+
/// Creates a ChatCompletionTool definition for the tool.
46+
fn definition() -> ChatCompletionTool {
47+
ChatCompletionTool {
48+
r#type: ChatCompletionToolType::Function,
49+
function: FunctionObject {
50+
name: Self::name(),
51+
description: Self::description(),
52+
parameters: Some(json!(schema_for!(Self::Args))),
53+
strict: Self::strict(),
54+
},
55+
}
56+
}
57+
58+
/// Executes the tool with the given arguments.
59+
/// Returns a Future that resolves to either the tool's output or an error.
60+
fn call(
61+
&self,
62+
args: Self::Args,
63+
) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send;
64+
}
65+
66+
/// A dynamic trait for tools that allows for runtime tool management.
67+
/// This trait provides a way to work with tools without knowing their concrete types at compile time.
68+
trait ToolDyn: Send + Sync {
69+
/// Returns the tool's definition as a ChatCompletionTool.
70+
fn definition(&self) -> ChatCompletionTool;
71+
72+
/// Executes the tool with the given JSON string arguments.
73+
/// Returns a Future that resolves to either a JSON string output or an error string.
74+
fn call(
75+
&self,
76+
args: String,
77+
) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
78+
}
79+
80+
// Implementation of ToolDyn for any type that implements Tool
81+
impl<T: Tool> ToolDyn for T {
82+
fn definition(&self) -> ChatCompletionTool {
83+
T::definition()
84+
}
85+
86+
fn call(
87+
&self,
88+
args: String,
89+
) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
90+
let future = async move {
91+
// Special handling for T::Args = () case
92+
// If the tool doesn't require arguments (T::Args is unit type),
93+
// we can safely ignore the provided arguments string
94+
match serde_json::from_str::<T::Args>(&args)
95+
.or_else(|e| serde_json::from_str::<T::Args>(&"null").map_err(|_| e))
96+
{
97+
Ok(args) => T::call(self, args)
98+
.await
99+
.map_err(|e| e.to_string())
100+
.and_then(|output| {
101+
serde_json::to_string(&output)
102+
.map_err(|e| format!("Failed to serialize output: {}", e))
103+
}),
104+
Err(e) => Err(format!("Failed to parse arguments: {}", e)),
105+
}
106+
};
107+
Box::pin(future)
108+
}
109+
}
110+
111+
/// A manager for tools that allows adding, retrieving, and executing tools.
112+
#[derive(Default, Clone)]
113+
pub struct ToolManager {
114+
/// A map of tool names to their dynamic implementations.
115+
tools: BTreeMap<String, Arc<dyn ToolDyn>>,
116+
}
117+
118+
impl ToolManager {
119+
/// Creates a new ToolManager.
120+
pub fn new() -> Self {
121+
Self {
122+
tools: BTreeMap::new(),
123+
}
124+
}
125+
126+
/// Adds a new tool to the manager.
127+
pub fn add_tool<T: Tool + 'static>(&mut self, tool: T) {
128+
self.tools.insert(T::name(), Arc::new(tool));
129+
}
130+
131+
/// Removes a tool from the manager.
132+
pub fn remove_tool(&mut self, name: &str) -> bool {
133+
self.tools.remove(name).is_some()
134+
}
135+
136+
/// Returns the definitions of all tools in the manager.
137+
pub fn get_tools(&self) -> Vec<ChatCompletionTool> {
138+
self.tools.values().map(|tool| tool.definition()).collect()
139+
}
140+
141+
/// Executes multiple tool calls concurrently and returns their results.
142+
pub async fn call(
143+
&self,
144+
calls: impl IntoIterator<Item = ChatCompletionMessageToolCall>,
145+
) -> Vec<ChatCompletionRequestToolMessage> {
146+
let mut handles = Vec::new();
147+
let mut outputs = Vec::new();
148+
149+
// Spawn a task for each tool call
150+
for call in calls {
151+
if let Some(tool) = self.tools.get(&call.function.name).cloned() {
152+
let handle = tokio::spawn(async move { tool.call(call.function.arguments).await });
153+
handles.push((call.id, handle));
154+
} else {
155+
outputs.push(ChatCompletionRequestToolMessage {
156+
content: "Tool call failed: tool not found".into(),
157+
tool_call_id: call.id,
158+
});
159+
}
160+
}
161+
162+
// Collect results from all spawned tasks
163+
for (id, handle) in handles {
164+
let output = match handle.await {
165+
Ok(Ok(output)) => output,
166+
Ok(Err(e)) => {
167+
format!("Tool call failed: {}", e)
168+
}
169+
Err(_) => {
170+
format!("Tool call failed: runtime error")
171+
}
172+
};
173+
outputs.push(ChatCompletionRequestToolMessage {
174+
content: output.into(),
175+
tool_call_id: id,
176+
});
177+
}
178+
outputs
179+
}
180+
}
181+
182+
/// A manager for handling streaming tool calls.
183+
/// This structure helps manage and merge tool call chunks that arrive in a streaming fashion.
184+
#[derive(Default, Clone, Debug)]
185+
pub struct ToolCallStreamManager(HashMap<u32, ChatCompletionMessageToolCall>);
186+
187+
impl ToolCallStreamManager {
188+
/// Creates a new empty ToolCallStreamManager.
189+
pub fn new() -> Self {
190+
Self(HashMap::new())
191+
}
192+
193+
/// Processes a single streaming tool call chunk and merges it with existing data.
194+
pub fn process_chunk(&mut self, chunk: ChatCompletionMessageToolCallChunk) {
195+
match self.0.entry(chunk.index) {
196+
Entry::Occupied(mut o) => {
197+
if let Some(FunctionCallStream {
198+
name: _,
199+
arguments: Some(arguments),
200+
}) = chunk.function
201+
{
202+
o.get_mut().function.arguments.push_str(&arguments);
203+
}
204+
}
205+
Entry::Vacant(o) => {
206+
let ChatCompletionMessageToolCallChunk {
207+
index: _,
208+
id: Some(id),
209+
r#type: _,
210+
function:
211+
Some(FunctionCallStream {
212+
name: Some(name),
213+
arguments: Some(arguments),
214+
}),
215+
} = chunk
216+
else {
217+
tracing::error!("Tool call chunk is not complete: {:?}", chunk);
218+
return;
219+
};
220+
let tool_call = ChatCompletionMessageToolCall {
221+
id,
222+
r#type: ChatCompletionToolType::Function,
223+
function: FunctionCall { name, arguments },
224+
};
225+
o.insert(tool_call);
226+
}
227+
}
228+
}
229+
230+
/// Processes multiple streaming tool call chunks and merges them with existing data.
231+
pub fn process_chunks(
232+
&mut self,
233+
chunks: impl IntoIterator<Item = ChatCompletionMessageToolCallChunk>,
234+
) {
235+
for chunk in chunks {
236+
self.process_chunk(chunk);
237+
}
238+
}
239+
240+
/// Returns all completed tool calls as a vector.
241+
pub fn finish_stream(self) -> Vec<ChatCompletionMessageToolCall> {
242+
self.0.into_values().collect()
243+
}
244+
}

examples/tool-call-stream/Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ publish = false
77
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
88

99
[dependencies]
10-
async-openai = {path = "../../async-openai"}
10+
async-openai = { path = "../../async-openai" }
1111
rand = "0.8.5"
12+
serde = "1.0"
1213
serde_json = "1.0.135"
1314
tokio = { version = "1.43.0", features = ["full"] }
1415
futures = "0.3.31"
16+
schemars = "0.8.22"

0 commit comments

Comments
 (0)