Skip to content

Commit a630743

Browse files
committed
update tool call examples
1 parent 47b5d0f commit a630743

File tree

5 files changed

+196
-290
lines changed

5 files changed

+196
-290
lines changed

async-openai/src/tools.rs

+7
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ pub struct ToolManager {
121121
}
122122

123123
impl ToolManager {
124+
/// Creates a new ToolManager.
125+
pub fn new() -> Self {
126+
Self {
127+
tools: BTreeMap::new(),
128+
}
129+
}
130+
124131
/// Adds a new tool to the manager.
125132
pub fn add_tool(&mut self, tool: impl Tool + 'static) {
126133
self.tools.insert(tool.name(), Arc::new(tool));

examples/tool-call-stream/Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ publish = false
77
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
88

99
[dependencies]
10-
async-openai = {path = "../../async-openai"}
10+
async-openai = { path = "../../async-openai" }
1111
rand = "0.8.5"
12+
serde = "1.0"
1213
serde_json = "1.0.135"
1314
tokio = { version = "1.43.0", features = ["full"] }
1415
futures = "0.3.31"
16+
schemars = "0.8.22"

examples/tool-call-stream/src/main.rs

+116-192
Original file line numberDiff line numberDiff line change
@@ -1,200 +1,52 @@
1-
use std::collections::HashMap;
21
use std::error::Error;
32
use std::io::{stdout, Write};
4-
use std::sync::Arc;
53

64
use async_openai::types::{
7-
ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
8-
ChatCompletionRequestMessage, ChatCompletionRequestToolMessageArgs,
9-
ChatCompletionRequestUserMessageArgs, ChatCompletionToolArgs, ChatCompletionToolType,
10-
FinishReason, FunctionCall, FunctionObjectArgs,
5+
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
6+
ChatCompletionRequestUserMessageArgs, FinishReason,
117
};
128
use async_openai::{types::CreateChatCompletionRequestArgs, Client};
9+
use async_openai::{Tool, ToolCallStreamManager, ToolManager};
1310
use futures::StreamExt;
1411
use rand::seq::SliceRandom;
1512
use rand::{thread_rng, Rng};
16-
use serde_json::{json, Value};
17-
use tokio::sync::Mutex;
13+
use schemars::JsonSchema;
14+
use serde::{Deserialize, Serialize};
1815

1916
#[tokio::main]
2017
async fn main() -> Result<(), Box<dyn std::error::Error>> {
2118
let client = Client::new();
22-
let user_prompt = "What's the weather like in Boston and Atlanta?";
19+
let mut messages = vec![ChatCompletionRequestUserMessageArgs::default()
20+
.content("What's the weather like in Boston and Atlanta?")
21+
.build()?
22+
.into()];
23+
let weather_tool = WeatherTool;
24+
let mut tool_manager = ToolManager::new();
25+
tool_manager.add_tool(weather_tool);
26+
let tools = tool_manager.get_tools();
2327

2428
let request = CreateChatCompletionRequestArgs::default()
2529
.max_tokens(512u32)
2630
.model("gpt-4-1106-preview")
27-
.messages([ChatCompletionRequestUserMessageArgs::default()
28-
.content(user_prompt)
29-
.build()?
30-
.into()])
31-
.tools(vec![ChatCompletionToolArgs::default()
32-
.r#type(ChatCompletionToolType::Function)
33-
.function(
34-
FunctionObjectArgs::default()
35-
.name("get_current_weather")
36-
.description("Get the current weather in a given location")
37-
.parameters(json!({
38-
"type": "object",
39-
"properties": {
40-
"location": {
41-
"type": "string",
42-
"description": "The city and state, e.g. San Francisco, CA",
43-
},
44-
"unit": { "type": "string", "enum": ["celsius", "fahrenheit"] },
45-
},
46-
"required": ["location"],
47-
}))
48-
.build()?,
49-
)
50-
.build()?])
31+
.messages(messages.clone())
32+
.tools(tools)
5133
.build()?;
5234

5335
let mut stream = client.chat().create_stream(request).await?;
5436

55-
let tool_call_states: Arc<Mutex<HashMap<(u32, u32), ChatCompletionMessageToolCall>>> =
56-
Arc::new(Mutex::new(HashMap::new()));
37+
let mut tool_call_stream_manager = ToolCallStreamManager::new();
5738

39+
let mut is_end_with_tool_call = false;
5840
while let Some(result) = stream.next().await {
5941
match result {
6042
Ok(response) => {
6143
for chat_choice in response.choices {
62-
let function_responses: Arc<
63-
Mutex<Vec<(ChatCompletionMessageToolCall, Value)>>,
64-
> = Arc::new(Mutex::new(Vec::new()));
65-
if let Some(tool_calls) = chat_choice.delta.tool_calls {
66-
for tool_call_chunk in tool_calls.into_iter() {
67-
let key = (chat_choice.index, tool_call_chunk.index);
68-
let states = tool_call_states.clone();
69-
let tool_call_data = tool_call_chunk.clone();
70-
71-
let mut states_lock = states.lock().await;
72-
let state = states_lock.entry(key).or_insert_with(|| {
73-
ChatCompletionMessageToolCall {
74-
id: tool_call_data.id.clone().unwrap_or_default(),
75-
r#type: ChatCompletionToolType::Function,
76-
function: FunctionCall {
77-
name: tool_call_data
78-
.function
79-
.as_ref()
80-
.and_then(|f| f.name.clone())
81-
.unwrap_or_default(),
82-
arguments: tool_call_data
83-
.function
84-
.as_ref()
85-
.and_then(|f| f.arguments.clone())
86-
.unwrap_or_default(),
87-
},
88-
}
89-
});
90-
if let Some(arguments) = tool_call_chunk
91-
.function
92-
.as_ref()
93-
.and_then(|f| f.arguments.as_ref())
94-
{
95-
state.function.arguments.push_str(arguments);
96-
}
97-
}
44+
if let Some(tool_call_chunks) = chat_choice.delta.tool_calls {
45+
tool_call_stream_manager.process_chunks(tool_call_chunks);
9846
}
9947
if let Some(finish_reason) = &chat_choice.finish_reason {
10048
if matches!(finish_reason, FinishReason::ToolCalls) {
101-
let tool_call_states_clone = tool_call_states.clone();
102-
103-
let tool_calls_to_process = {
104-
let states_lock = tool_call_states_clone.lock().await;
105-
states_lock
106-
.iter()
107-
.map(|(_key, tool_call)| {
108-
let name = tool_call.function.name.clone();
109-
let args = tool_call.function.arguments.clone();
110-
let tool_call_clone = tool_call.clone();
111-
(name, args, tool_call_clone)
112-
})
113-
.collect::<Vec<_>>()
114-
};
115-
116-
let mut handles = Vec::new();
117-
for (name, args, tool_call_clone) in tool_calls_to_process {
118-
let response_content_clone = function_responses.clone();
119-
let handle = tokio::spawn(async move {
120-
let response_content = call_fn(&name, &args).await.unwrap();
121-
let mut function_responses_lock =
122-
response_content_clone.lock().await;
123-
function_responses_lock
124-
.push((tool_call_clone, response_content));
125-
});
126-
handles.push(handle);
127-
}
128-
129-
for handle in handles {
130-
handle.await.unwrap();
131-
}
132-
133-
let function_responses_clone = function_responses.clone();
134-
let function_responses_lock = function_responses_clone.lock().await;
135-
let mut messages: Vec<ChatCompletionRequestMessage> =
136-
vec![ChatCompletionRequestUserMessageArgs::default()
137-
.content(user_prompt)
138-
.build()?
139-
.into()];
140-
141-
let tool_calls: Vec<ChatCompletionMessageToolCall> =
142-
function_responses_lock
143-
.iter()
144-
.map(|tc| tc.0.clone())
145-
.collect();
146-
147-
let assistant_messages: ChatCompletionRequestMessage =
148-
ChatCompletionRequestAssistantMessageArgs::default()
149-
.tool_calls(tool_calls)
150-
.build()
151-
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
152-
.unwrap()
153-
.into();
154-
155-
let tool_messages: Vec<ChatCompletionRequestMessage> =
156-
function_responses_lock
157-
.iter()
158-
.map(|tc| {
159-
ChatCompletionRequestToolMessageArgs::default()
160-
.content(tc.1.to_string())
161-
.tool_call_id(tc.0.id.clone())
162-
.build()
163-
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
164-
.unwrap()
165-
.into()
166-
})
167-
.collect();
168-
169-
messages.push(assistant_messages);
170-
messages.extend(tool_messages);
171-
172-
let request = CreateChatCompletionRequestArgs::default()
173-
.max_tokens(512u32)
174-
.model("gpt-4-1106-preview")
175-
.messages(messages)
176-
.build()
177-
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
178-
179-
let mut stream = client.chat().create_stream(request).await?;
180-
181-
let mut response_content = String::new();
182-
let mut lock = stdout().lock();
183-
while let Some(result) = stream.next().await {
184-
match result {
185-
Ok(response) => {
186-
for chat_choice in response.choices.iter() {
187-
if let Some(ref content) = chat_choice.delta.content {
188-
write!(lock, "{}", content).unwrap();
189-
response_content.push_str(content);
190-
}
191-
}
192-
}
193-
Err(err) => {
194-
return Err(Box::new(err) as Box<dyn std::error::Error>);
195-
}
196-
}
197-
}
49+
is_end_with_tool_call = true;
19850
}
19951
}
20052

@@ -214,40 +66,112 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
21466
.map_err(|e| Box::new(e) as Box<dyn Error>)?;
21567
}
21668

69+
if !is_end_with_tool_call {
70+
return Err("The response is not ended with tool call".into());
71+
}
72+
let tool_calls = tool_call_stream_manager.finish_stream();
73+
let function_responses = tool_manager.call(tool_calls.clone()).await;
74+
75+
let assistant_messages: ChatCompletionRequestMessage =
76+
ChatCompletionRequestAssistantMessageArgs::default()
77+
.tool_calls(tool_calls)
78+
.build()
79+
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
80+
.unwrap()
81+
.into();
82+
83+
let tool_messages: Vec<ChatCompletionRequestMessage> = function_responses
84+
.into_iter()
85+
.map(|res| res.into())
86+
.collect();
87+
88+
messages.push(assistant_messages);
89+
messages.extend(tool_messages);
90+
91+
let request = CreateChatCompletionRequestArgs::default()
92+
.max_tokens(512u32)
93+
.model("gpt-4-1106-preview")
94+
.messages(messages)
95+
.build()
96+
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
97+
98+
let mut stream = client.chat().create_stream(request).await?;
99+
100+
let mut response_content = String::new();
101+
let mut lock = stdout().lock();
102+
while let Some(result) = stream.next().await {
103+
match result {
104+
Ok(response) => {
105+
for chat_choice in response.choices.iter() {
106+
if let Some(ref content) = chat_choice.delta.content {
107+
write!(lock, "{}", content).unwrap();
108+
response_content.push_str(content);
109+
}
110+
}
111+
}
112+
Err(err) => {
113+
return Err(Box::new(err) as Box<dyn std::error::Error>);
114+
}
115+
}
116+
}
117+
217118
Ok(())
218119
}
219120

220-
async fn call_fn(name: &str, args: &str) -> Result<Value, Box<dyn std::error::Error>> {
221-
let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
222-
HashMap::new();
223-
available_functions.insert("get_current_weather", get_current_weather);
121+
#[derive(Debug, JsonSchema, Deserialize, Serialize)]
122+
enum Unit {
123+
Fahrenheit,
124+
Celsius,
125+
}
224126

225-
let function_args: serde_json::Value = args.parse().unwrap();
127+
#[derive(Debug, JsonSchema, Deserialize)]
128+
struct WeatherRequest {
129+
/// The city and state, e.g. San Francisco, CA
130+
location: String,
131+
unit: Unit,
132+
}
226133

227-
let location = function_args["location"].as_str().unwrap();
228-
let unit = function_args["unit"].as_str().unwrap_or("fahrenheit");
229-
let function = available_functions.get(name).unwrap();
230-
let function_response = function(location, unit);
231-
Ok(function_response)
134+
#[derive(Debug, Serialize)]
135+
struct WeatherResponse {
136+
location: String,
137+
temperature: String,
138+
unit: Unit,
139+
forecast: String,
232140
}
233141

234-
fn get_current_weather(location: &str, unit: &str) -> serde_json::Value {
235-
let mut rng = thread_rng();
142+
struct WeatherTool;
143+
144+
impl Tool for WeatherTool {
145+
type Args = WeatherRequest;
146+
type Output = WeatherResponse;
147+
type Error = String;
148+
149+
fn name(&self) -> String {
150+
"get_current_weather".to_string()
151+
}
152+
153+
fn description(&self) -> Option<String> {
154+
Some("Get the current weather in a given location".to_string())
155+
}
156+
157+
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
158+
let mut rng = thread_rng();
236159

237-
let temperature: i32 = rng.gen_range(20..=55);
160+
let temperature: i32 = rng.gen_range(20..=55);
238161

239-
let forecasts = [
240-
"sunny", "cloudy", "overcast", "rainy", "windy", "foggy", "snowy",
241-
];
162+
let forecasts = [
163+
"sunny", "cloudy", "overcast", "rainy", "windy", "foggy", "snowy",
164+
];
242165

243-
let forecast = forecasts.choose(&mut rng).unwrap_or(&"sunny");
166+
let forecast = forecasts.choose(&mut rng).unwrap_or(&"sunny");
244167

245-
let weather_info = json!({
246-
"location": location,
247-
"temperature": temperature.to_string(),
248-
"unit": unit,
249-
"forecast": forecast
250-
});
168+
let weather_info = WeatherResponse {
169+
location: args.location,
170+
temperature: temperature.to_string(),
171+
unit: args.unit,
172+
forecast: forecast.to_string(),
173+
};
251174

252-
weather_info
175+
Ok(weather_info)
176+
}
253177
}

examples/tool-call/Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ publish = false
77
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
88

99
[dependencies]
10-
async-openai = {path = "../../async-openai"}
10+
async-openai = { path = "../../async-openai" }
1111
rand = "0.8.5"
12+
serde = "1.0"
1213
serde_json = "1.0.135"
1314
tokio = { version = "1.43.0", features = ["full"] }
1415
futures = "0.3.31"
16+
schemars = "0.8.22"

0 commit comments

Comments
 (0)