Skip to content

feat: tool management #359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ target
Cargo.lock
**/*.rs.bk
.DS_Store
.env

# directory used to store images
data
Expand Down
1 change: 1 addition & 0 deletions async-openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ secrecy = { version = "0.10.3", features = ["serde"] }
bytes = "1.9.0"
eventsource-stream = "0.2.3"
tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false }
schemars = "0.8.22"

[dev-dependencies]
tokio-test = "0.4.4"
Expand Down
2 changes: 2 additions & 0 deletions async-openai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ mod projects;
mod runs;
mod steps;
mod threads;
mod tools;
pub mod traits;
pub mod types;
mod uploads;
Expand Down Expand Up @@ -180,6 +181,7 @@ pub use projects::Projects;
pub use runs::Runs;
pub use steps::Steps;
pub use threads::Threads;
pub use tools::{Tool, ToolManager, ToolCallStreamManager};
pub use uploads::Uploads;
pub use users::Users;
pub use vector_store_file_batches::VectorStoreFileBatches;
Expand Down
244 changes: 244 additions & 0 deletions async-openai/src/tools.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
//! This module provides functionality for managing and executing tools in an async OpenAI context.
//! It defines traits and structures for tool management, execution, and streaming.
use std::{
collections::{hash_map::Entry, BTreeMap, HashMap},
future::Future,
pin::Pin,
sync::Arc,
};

use schemars::{schema_for, JsonSchema};
use serde::{Deserialize, Serialize};
use serde_json::json;

use crate::types::{
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
ChatCompletionRequestToolMessage, ChatCompletionTool, ChatCompletionToolType, FunctionCall,
FunctionCallStream, FunctionObject,
};

/// A trait defining the interface for tools that can be used with the OpenAI API.
/// Tools must implement this trait to be used with the ToolManager.
pub trait Tool: Send + Sync {
/// The type of arguments that the tool accepts.
type Args: JsonSchema + for<'a> Deserialize<'a> + Send + Sync;
/// The type of output that the tool produces.
type Output: Serialize + Send + Sync;
/// The type of error that the tool can return.
type Error: ToString + Send + Sync;

/// Returns the name of the tool.
fn name() -> String {
Self::Args::schema_name()
}

/// Returns an optional description of the tool.
fn description() -> Option<String> {
None
}

/// Returns an optional boolean indicating whether the tool should be strict about the arguments.
fn strict() -> Option<bool> {
None
}

/// Creates a ChatCompletionTool definition for the tool.
fn definition() -> ChatCompletionTool {
ChatCompletionTool {
r#type: ChatCompletionToolType::Function,
function: FunctionObject {
name: Self::name(),
description: Self::description(),
parameters: Some(json!(schema_for!(Self::Args))),
strict: Self::strict(),
},
}
}

/// Executes the tool with the given arguments.
/// Returns a Future that resolves to either the tool's output or an error.
fn call(
&self,
args: Self::Args,
) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send;
}

/// A dynamic trait for tools that allows for runtime tool management.
/// This trait provides a way to work with tools without knowing their concrete types at compile time.
trait ToolDyn: Send + Sync {
/// Returns the tool's definition as a ChatCompletionTool.
fn definition(&self) -> ChatCompletionTool;

/// Executes the tool with the given JSON string arguments.
/// Returns a Future that resolves to either a JSON string output or an error string.
fn call(
&self,
args: String,
) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
}

// Implementation of ToolDyn for any type that implements Tool
impl<T: Tool> ToolDyn for T {
fn definition(&self) -> ChatCompletionTool {
T::definition()
}

fn call(
&self,
args: String,
) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
let future = async move {
// Special handling for T::Args = () case
// If the tool doesn't require arguments (T::Args is unit type),
// we can safely ignore the provided arguments string
match serde_json::from_str::<T::Args>(&args)
.or_else(|e| serde_json::from_str::<T::Args>(&"null").map_err(|_| e))
{
Ok(args) => T::call(self, args)
.await
.map_err(|e| e.to_string())
.and_then(|output| {
serde_json::to_string(&output)
.map_err(|e| format!("Failed to serialize output: {}", e))
}),
Err(e) => Err(format!("Failed to parse arguments: {}", e)),
}
};
Box::pin(future)
}
}

/// A manager for tools that allows adding, retrieving, and executing tools.
#[derive(Default, Clone)]
pub struct ToolManager {
/// A map of tool names to their dynamic implementations.
tools: BTreeMap<String, Arc<dyn ToolDyn>>,
}

impl ToolManager {
/// Creates a new ToolManager.
pub fn new() -> Self {
Self {
tools: BTreeMap::new(),
}
}

/// Adds a new tool to the manager.
pub fn add_tool<T: Tool + 'static>(&mut self, tool: T) {
self.tools.insert(T::name(), Arc::new(tool));
}

/// Removes a tool from the manager.
pub fn remove_tool(&mut self, name: &str) -> bool {
self.tools.remove(name).is_some()
}

/// Returns the definitions of all tools in the manager.
pub fn get_tools(&self) -> Vec<ChatCompletionTool> {
self.tools.values().map(|tool| tool.definition()).collect()
}

/// Executes multiple tool calls concurrently and returns their results.
pub async fn call(
&self,
calls: impl IntoIterator<Item = ChatCompletionMessageToolCall>,
) -> Vec<ChatCompletionRequestToolMessage> {
let mut handles = Vec::new();
let mut outputs = Vec::new();

// Spawn a task for each tool call
for call in calls {
if let Some(tool) = self.tools.get(&call.function.name).cloned() {
let handle = tokio::spawn(async move { tool.call(call.function.arguments).await });
handles.push((call.id, handle));
} else {
outputs.push(ChatCompletionRequestToolMessage {
content: "Tool call failed: tool not found".into(),
tool_call_id: call.id,
});
}
}

// Collect results from all spawned tasks
for (id, handle) in handles {
let output = match handle.await {
Ok(Ok(output)) => output,
Ok(Err(e)) => {
format!("Tool call failed: {}", e)
}
Err(_) => {
format!("Tool call failed: runtime error")
}
};
outputs.push(ChatCompletionRequestToolMessage {
content: output.into(),
tool_call_id: id,
});
}
outputs
}
}

/// A manager for handling streaming tool calls.
/// This structure helps manage and merge tool call chunks that arrive in a streaming fashion.
#[derive(Default, Clone, Debug)]
pub struct ToolCallStreamManager(HashMap<u32, ChatCompletionMessageToolCall>);

impl ToolCallStreamManager {
/// Creates a new empty ToolCallStreamManager.
pub fn new() -> Self {
Self(HashMap::new())
}

/// Processes a single streaming tool call chunk and merges it with existing data.
pub fn process_chunk(&mut self, chunk: ChatCompletionMessageToolCallChunk) {
match self.0.entry(chunk.index) {
Entry::Occupied(mut o) => {
if let Some(FunctionCallStream {
name: _,
arguments: Some(arguments),
}) = chunk.function
{
o.get_mut().function.arguments.push_str(&arguments);
}
}
Entry::Vacant(o) => {
let ChatCompletionMessageToolCallChunk {
index: _,
id: Some(id),
r#type: _,
function:
Some(FunctionCallStream {
name: Some(name),
arguments: Some(arguments),
}),
} = chunk
else {
tracing::error!("Tool call chunk is not complete: {:?}", chunk);
return;
};
let tool_call = ChatCompletionMessageToolCall {
id,
r#type: ChatCompletionToolType::Function,
function: FunctionCall { name, arguments },
};
o.insert(tool_call);
}
}
}

/// Processes multiple streaming tool call chunks and merges them with existing data.
pub fn process_chunks(
&mut self,
chunks: impl IntoIterator<Item = ChatCompletionMessageToolCallChunk>,
) {
for chunk in chunks {
self.process_chunk(chunk);
}
}

/// Returns all completed tool calls as a vector.
pub fn finish_stream(self) -> Vec<ChatCompletionMessageToolCall> {
self.0.into_values().collect()
}
}
4 changes: 3 additions & 1 deletion examples/tool-call-stream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ publish = false
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-openai = {path = "../../async-openai"}
async-openai = { path = "../../async-openai" }
rand = "0.8.5"
serde = "1.0"
serde_json = "1.0.135"
tokio = { version = "1.43.0", features = ["full"] }
futures = "0.3.31"
schemars = "0.8.22"
Loading