1
- use std:: collections:: HashMap ;
2
1
use std:: error:: Error ;
3
2
use std:: io:: { stdout, Write } ;
4
- use std:: sync:: Arc ;
5
3
6
4
use async_openai:: types:: {
7
- ChatCompletionMessageToolCall , ChatCompletionRequestAssistantMessageArgs ,
8
- ChatCompletionRequestMessage , ChatCompletionRequestToolMessageArgs ,
9
- ChatCompletionRequestUserMessageArgs , ChatCompletionToolArgs , ChatCompletionToolType ,
10
- FinishReason , FunctionCall , FunctionObjectArgs ,
5
+ ChatCompletionRequestAssistantMessageArgs , ChatCompletionRequestMessage ,
6
+ ChatCompletionRequestUserMessageArgs , FinishReason ,
11
7
} ;
12
8
use async_openai:: { types:: CreateChatCompletionRequestArgs , Client } ;
9
+ use async_openai:: { Tool , ToolCallStreamManager , ToolManager } ;
13
10
use futures:: StreamExt ;
14
11
use rand:: seq:: SliceRandom ;
15
12
use rand:: { thread_rng, Rng } ;
16
- use serde_json :: { json , Value } ;
17
- use tokio :: sync :: Mutex ;
13
+ use schemars :: JsonSchema ;
14
+ use serde :: { Deserialize , Serialize } ;
18
15
19
16
#[ tokio:: main]
20
17
async fn main ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
21
18
let client = Client :: new ( ) ;
22
- let user_prompt = "What's the weather like in Boston and Atlanta?" ;
19
+ let mut messages = vec ! [ ChatCompletionRequestUserMessageArgs :: default ( )
20
+ . content( "What's the weather like in Boston and Atlanta?" )
21
+ . build( ) ?
22
+ . into( ) ] ;
23
+ let weather_tool = WeatherTool ;
24
+ let mut tool_manager = ToolManager :: new ( ) ;
25
+ tool_manager. add_tool ( weather_tool) ;
26
+ let tools = tool_manager. get_tools ( ) ;
23
27
24
28
let request = CreateChatCompletionRequestArgs :: default ( )
25
29
. max_tokens ( 512u32 )
26
30
. model ( "gpt-4-1106-preview" )
27
- . messages ( [ ChatCompletionRequestUserMessageArgs :: default ( )
28
- . content ( user_prompt)
29
- . build ( ) ?
30
- . into ( ) ] )
31
- . tools ( vec ! [ ChatCompletionToolArgs :: default ( )
32
- . r#type( ChatCompletionToolType :: Function )
33
- . function(
34
- FunctionObjectArgs :: default ( )
35
- . name( "get_current_weather" )
36
- . description( "Get the current weather in a given location" )
37
- . parameters( json!( {
38
- "type" : "object" ,
39
- "properties" : {
40
- "location" : {
41
- "type" : "string" ,
42
- "description" : "The city and state, e.g. San Francisco, CA" ,
43
- } ,
44
- "unit" : { "type" : "string" , "enum" : [ "celsius" , "fahrenheit" ] } ,
45
- } ,
46
- "required" : [ "location" ] ,
47
- } ) )
48
- . build( ) ?,
49
- )
50
- . build( ) ?] )
31
+ . messages ( messages. clone ( ) )
32
+ . tools ( tools)
51
33
. build ( ) ?;
52
34
53
35
let mut stream = client. chat ( ) . create_stream ( request) . await ?;
54
36
55
- let tool_call_states: Arc < Mutex < HashMap < ( u32 , u32 ) , ChatCompletionMessageToolCall > > > =
56
- Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
37
+ let mut tool_call_stream_manager = ToolCallStreamManager :: new ( ) ;
57
38
39
+ let mut is_end_with_tool_call = false ;
58
40
while let Some ( result) = stream. next ( ) . await {
59
41
match result {
60
42
Ok ( response) => {
61
43
for chat_choice in response. choices {
62
- let function_responses: Arc <
63
- Mutex < Vec < ( ChatCompletionMessageToolCall , Value ) > > ,
64
- > = Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ;
65
- if let Some ( tool_calls) = chat_choice. delta . tool_calls {
66
- for tool_call_chunk in tool_calls. into_iter ( ) {
67
- let key = ( chat_choice. index , tool_call_chunk. index ) ;
68
- let states = tool_call_states. clone ( ) ;
69
- let tool_call_data = tool_call_chunk. clone ( ) ;
70
-
71
- let mut states_lock = states. lock ( ) . await ;
72
- let state = states_lock. entry ( key) . or_insert_with ( || {
73
- ChatCompletionMessageToolCall {
74
- id : tool_call_data. id . clone ( ) . unwrap_or_default ( ) ,
75
- r#type : ChatCompletionToolType :: Function ,
76
- function : FunctionCall {
77
- name : tool_call_data
78
- . function
79
- . as_ref ( )
80
- . and_then ( |f| f. name . clone ( ) )
81
- . unwrap_or_default ( ) ,
82
- arguments : tool_call_data
83
- . function
84
- . as_ref ( )
85
- . and_then ( |f| f. arguments . clone ( ) )
86
- . unwrap_or_default ( ) ,
87
- } ,
88
- }
89
- } ) ;
90
- if let Some ( arguments) = tool_call_chunk
91
- . function
92
- . as_ref ( )
93
- . and_then ( |f| f. arguments . as_ref ( ) )
94
- {
95
- state. function . arguments . push_str ( arguments) ;
96
- }
97
- }
44
+ if let Some ( tool_call_chunks) = chat_choice. delta . tool_calls {
45
+ tool_call_stream_manager. process_chunks ( tool_call_chunks) ;
98
46
}
99
47
if let Some ( finish_reason) = & chat_choice. finish_reason {
100
48
if matches ! ( finish_reason, FinishReason :: ToolCalls ) {
101
- let tool_call_states_clone = tool_call_states. clone ( ) ;
102
-
103
- let tool_calls_to_process = {
104
- let states_lock = tool_call_states_clone. lock ( ) . await ;
105
- states_lock
106
- . iter ( )
107
- . map ( |( _key, tool_call) | {
108
- let name = tool_call. function . name . clone ( ) ;
109
- let args = tool_call. function . arguments . clone ( ) ;
110
- let tool_call_clone = tool_call. clone ( ) ;
111
- ( name, args, tool_call_clone)
112
- } )
113
- . collect :: < Vec < _ > > ( )
114
- } ;
115
-
116
- let mut handles = Vec :: new ( ) ;
117
- for ( name, args, tool_call_clone) in tool_calls_to_process {
118
- let response_content_clone = function_responses. clone ( ) ;
119
- let handle = tokio:: spawn ( async move {
120
- let response_content = call_fn ( & name, & args) . await . unwrap ( ) ;
121
- let mut function_responses_lock =
122
- response_content_clone. lock ( ) . await ;
123
- function_responses_lock
124
- . push ( ( tool_call_clone, response_content) ) ;
125
- } ) ;
126
- handles. push ( handle) ;
127
- }
128
-
129
- for handle in handles {
130
- handle. await . unwrap ( ) ;
131
- }
132
-
133
- let function_responses_clone = function_responses. clone ( ) ;
134
- let function_responses_lock = function_responses_clone. lock ( ) . await ;
135
- let mut messages: Vec < ChatCompletionRequestMessage > =
136
- vec ! [ ChatCompletionRequestUserMessageArgs :: default ( )
137
- . content( user_prompt)
138
- . build( ) ?
139
- . into( ) ] ;
140
-
141
- let tool_calls: Vec < ChatCompletionMessageToolCall > =
142
- function_responses_lock
143
- . iter ( )
144
- . map ( |tc| tc. 0 . clone ( ) )
145
- . collect ( ) ;
146
-
147
- let assistant_messages: ChatCompletionRequestMessage =
148
- ChatCompletionRequestAssistantMessageArgs :: default ( )
149
- . tool_calls ( tool_calls)
150
- . build ( )
151
- . map_err ( |e| Box :: new ( e) as Box < dyn std:: error:: Error > )
152
- . unwrap ( )
153
- . into ( ) ;
154
-
155
- let tool_messages: Vec < ChatCompletionRequestMessage > =
156
- function_responses_lock
157
- . iter ( )
158
- . map ( |tc| {
159
- ChatCompletionRequestToolMessageArgs :: default ( )
160
- . content ( tc. 1 . to_string ( ) )
161
- . tool_call_id ( tc. 0 . id . clone ( ) )
162
- . build ( )
163
- . map_err ( |e| Box :: new ( e) as Box < dyn std:: error:: Error > )
164
- . unwrap ( )
165
- . into ( )
166
- } )
167
- . collect ( ) ;
168
-
169
- messages. push ( assistant_messages) ;
170
- messages. extend ( tool_messages) ;
171
-
172
- let request = CreateChatCompletionRequestArgs :: default ( )
173
- . max_tokens ( 512u32 )
174
- . model ( "gpt-4-1106-preview" )
175
- . messages ( messages)
176
- . build ( )
177
- . map_err ( |e| Box :: new ( e) as Box < dyn std:: error:: Error > ) ?;
178
-
179
- let mut stream = client. chat ( ) . create_stream ( request) . await ?;
180
-
181
- let mut response_content = String :: new ( ) ;
182
- let mut lock = stdout ( ) . lock ( ) ;
183
- while let Some ( result) = stream. next ( ) . await {
184
- match result {
185
- Ok ( response) => {
186
- for chat_choice in response. choices . iter ( ) {
187
- if let Some ( ref content) = chat_choice. delta . content {
188
- write ! ( lock, "{}" , content) . unwrap ( ) ;
189
- response_content. push_str ( content) ;
190
- }
191
- }
192
- }
193
- Err ( err) => {
194
- return Err ( Box :: new ( err) as Box < dyn std:: error:: Error > ) ;
195
- }
196
- }
197
- }
49
+ is_end_with_tool_call = true ;
198
50
}
199
51
}
200
52
@@ -214,40 +66,112 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
214
66
. map_err ( |e| Box :: new ( e) as Box < dyn Error > ) ?;
215
67
}
216
68
69
+ if !is_end_with_tool_call {
70
+ return Err ( "The response is not ended with tool call" . into ( ) ) ;
71
+ }
72
+ let tool_calls = tool_call_stream_manager. finish_stream ( ) ;
73
+ let function_responses = tool_manager. call ( tool_calls. clone ( ) ) . await ;
74
+
75
+ let assistant_messages: ChatCompletionRequestMessage =
76
+ ChatCompletionRequestAssistantMessageArgs :: default ( )
77
+ . tool_calls ( tool_calls)
78
+ . build ( )
79
+ . map_err ( |e| Box :: new ( e) as Box < dyn std:: error:: Error > )
80
+ . unwrap ( )
81
+ . into ( ) ;
82
+
83
+ let tool_messages: Vec < ChatCompletionRequestMessage > = function_responses
84
+ . into_iter ( )
85
+ . map ( |res| res. into ( ) )
86
+ . collect ( ) ;
87
+
88
+ messages. push ( assistant_messages) ;
89
+ messages. extend ( tool_messages) ;
90
+
91
+ let request = CreateChatCompletionRequestArgs :: default ( )
92
+ . max_tokens ( 512u32 )
93
+ . model ( "gpt-4-1106-preview" )
94
+ . messages ( messages)
95
+ . build ( )
96
+ . map_err ( |e| Box :: new ( e) as Box < dyn std:: error:: Error > ) ?;
97
+
98
+ let mut stream = client. chat ( ) . create_stream ( request) . await ?;
99
+
100
+ let mut response_content = String :: new ( ) ;
101
+ let mut lock = stdout ( ) . lock ( ) ;
102
+ while let Some ( result) = stream. next ( ) . await {
103
+ match result {
104
+ Ok ( response) => {
105
+ for chat_choice in response. choices . iter ( ) {
106
+ if let Some ( ref content) = chat_choice. delta . content {
107
+ write ! ( lock, "{}" , content) . unwrap ( ) ;
108
+ response_content. push_str ( content) ;
109
+ }
110
+ }
111
+ }
112
+ Err ( err) => {
113
+ return Err ( Box :: new ( err) as Box < dyn std:: error:: Error > ) ;
114
+ }
115
+ }
116
+ }
117
+
217
118
Ok ( ( ) )
218
119
}
219
120
220
- async fn call_fn ( name : & str , args : & str ) -> Result < Value , Box < dyn std:: error:: Error > > {
221
- let mut available_functions: HashMap < & str , fn ( & str , & str ) -> serde_json:: Value > =
222
- HashMap :: new ( ) ;
223
- available_functions. insert ( "get_current_weather" , get_current_weather) ;
121
+ #[ derive( Debug , JsonSchema , Deserialize , Serialize ) ]
122
+ enum Unit {
123
+ Fahrenheit ,
124
+ Celsius ,
125
+ }
224
126
225
- let function_args: serde_json:: Value = args. parse ( ) . unwrap ( ) ;
127
+ #[ derive( Debug , JsonSchema , Deserialize ) ]
128
+ struct WeatherRequest {
129
+ /// The city and state, e.g. San Francisco, CA
130
+ location : String ,
131
+ unit : Unit ,
132
+ }
226
133
227
- let location = function_args[ "location" ] . as_str ( ) . unwrap ( ) ;
228
- let unit = function_args[ "unit" ] . as_str ( ) . unwrap_or ( "fahrenheit" ) ;
229
- let function = available_functions. get ( name) . unwrap ( ) ;
230
- let function_response = function ( location, unit) ;
231
- Ok ( function_response)
134
+ #[ derive( Debug , Serialize ) ]
135
+ struct WeatherResponse {
136
+ location : String ,
137
+ temperature : String ,
138
+ unit : Unit ,
139
+ forecast : String ,
232
140
}
233
141
234
- fn get_current_weather ( location : & str , unit : & str ) -> serde_json:: Value {
235
- let mut rng = thread_rng ( ) ;
142
+ struct WeatherTool ;
143
+
144
+ impl Tool for WeatherTool {
145
+ type Args = WeatherRequest ;
146
+ type Output = WeatherResponse ;
147
+ type Error = String ;
148
+
149
+ fn name ( & self ) -> String {
150
+ "get_current_weather" . to_string ( )
151
+ }
152
+
153
+ fn description ( & self ) -> Option < String > {
154
+ Some ( "Get the current weather in a given location" . to_string ( ) )
155
+ }
156
+
157
+ async fn call ( & self , args : Self :: Args ) -> Result < Self :: Output , Self :: Error > {
158
+ let mut rng = thread_rng ( ) ;
236
159
237
- let temperature: i32 = rng. gen_range ( 20 ..=55 ) ;
160
+ let temperature: i32 = rng. gen_range ( 20 ..=55 ) ;
238
161
239
- let forecasts = [
240
- "sunny" , "cloudy" , "overcast" , "rainy" , "windy" , "foggy" , "snowy" ,
241
- ] ;
162
+ let forecasts = [
163
+ "sunny" , "cloudy" , "overcast" , "rainy" , "windy" , "foggy" , "snowy" ,
164
+ ] ;
242
165
243
- let forecast = forecasts. choose ( & mut rng) . unwrap_or ( & "sunny" ) ;
166
+ let forecast = forecasts. choose ( & mut rng) . unwrap_or ( & "sunny" ) ;
244
167
245
- let weather_info = json ! ( {
246
- " location" : location,
247
- " temperature" : temperature. to_string( ) ,
248
- " unit" : unit,
249
- " forecast" : forecast
250
- } ) ;
168
+ let weather_info = WeatherResponse {
169
+ location : args . location ,
170
+ temperature : temperature. to_string ( ) ,
171
+ unit : args . unit ,
172
+ forecast : forecast. to_string ( ) ,
173
+ } ;
251
174
252
- weather_info
175
+ Ok ( weather_info)
176
+ }
253
177
}
0 commit comments