Skip to content

Commit 638bf75

Browse files
authored
feat: Bring your own types (64bit#342)
* attribute proc macro to bring your own types * keep original fn as it is add new with _byot suffix * update macro * update macro * use macro in main crate + add test * byot: assistants * byot: vector_stores * add where_clause attribute arg * remove print * byot: files * byot: images * add stream arg to attribute * byot: chat * byot: completions * fix comment * fix * byot: audio * byot: embeddings * byot: Fine Tunning * add byot tests * byot: moderations * byot tests: moderations * byot: threads * byot tests: threads * byot: messages * byot tests: messages * byot: runs * byot tests: runs * byot: steps * byot tests: run steps * byot: vector store files * byot test: vector store files * byot: vector store file batches * byot test: vector store file batches * cargo fmt * byot: batches * byot tests: batches * format * remove AssistantFiles and related apis (/assistants/assistant_id/files/..) * byot: audit logs * byot tests: audit logs * keep non byot code checks * byot: invites * byot tests: invites * remove message files API * byot: project api keys * byot tests: project api keys * byot: project service accounts * byot tests: project service accounts * byot: project users * byot tests: project users * byot: projects * byot tests: projects * byot: uploads * byot tests: uploads * byot: users * byot tests: users * add example to demonstrate bring-your-own-types * update README * update doc * cargo fmt * update doc in lib.rs * tests passing * fix for complier warning * fix compiler #[allow(unused_mut)] * cargo fix * fix all warnings * add Voices * publish = false for all examples * specify versions
1 parent c48e62e commit 638bf75

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1133
-317
lines changed

Cargo.toml

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
[workspace]
2-
members = [ "async-openai", "examples/*" ]
2+
members = [ "async-openai", "async-openai-*", "examples/*" ]
33
# Only check / build main crates by default (check all with `--workspace`)
4-
default-members = ["async-openai"]
4+
default-members = ["async-openai", "async-openai-*"]
55
resolver = "2"
6+
7+
[workspace.package]
8+
rust-version = "1.75"

async-openai-macros/Cargo.toml

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
[package]
2+
name = "async-openai-macros"
3+
version = "0.1.0"
4+
authors = ["Himanshu Neema"]
5+
keywords = ["openai", "macros", "ai"]
6+
description = "Macros for async-openai"
7+
edition = "2021"
8+
license = "MIT"
9+
homepage = "https://github.com/64bit/async-openai"
10+
repository = "https://github.com/64bit/async-openai"
11+
rust-version = { workspace = true }
12+
13+
[lib]
14+
proc-macro = true
15+
16+
[dependencies]
17+
syn = { version = "2.0", features = ["full"] }
18+
quote = "1.0"
19+
proc-macro2 = "1.0"

async-openai-macros/src/lib.rs

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
use proc_macro::TokenStream;
2+
use quote::{quote, ToTokens};
3+
use syn::{
4+
parse::{Parse, ParseStream},
5+
parse_macro_input,
6+
punctuated::Punctuated,
7+
token::Comma,
8+
FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause,
9+
};
10+
11+
// Parse attribute arguments like #[byot(T0: Display + Debug, T1: Clone, R: Serialize)]
12+
struct BoundArgs {
13+
bounds: Vec<(String, syn::TypeParamBound)>,
14+
where_clause: Option<String>,
15+
stream: bool, // Add stream flag
16+
}
17+
18+
impl Parse for BoundArgs {
19+
fn parse(input: ParseStream) -> syn::Result<Self> {
20+
let mut bounds = Vec::new();
21+
let mut where_clause = None;
22+
let mut stream = false; // Default to false
23+
let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?;
24+
25+
for var in vars {
26+
let name = var.path.get_ident().unwrap().to_string();
27+
match name.as_str() {
28+
"where_clause" => {
29+
where_clause = Some(var.value.into_token_stream().to_string());
30+
}
31+
"stream" => {
32+
stream = var.value.into_token_stream().to_string().contains("true");
33+
}
34+
_ => {
35+
let bound: syn::TypeParamBound =
36+
syn::parse_str(&var.value.into_token_stream().to_string())?;
37+
bounds.push((name, bound));
38+
}
39+
}
40+
}
41+
Ok(BoundArgs {
42+
bounds,
43+
where_clause,
44+
stream,
45+
})
46+
}
47+
}
48+
49+
#[proc_macro_attribute]
50+
pub fn byot_passthrough(_args: TokenStream, item: TokenStream) -> TokenStream {
51+
item
52+
}
53+
54+
#[proc_macro_attribute]
55+
pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
56+
let bounds_args = parse_macro_input!(args as BoundArgs);
57+
let input = parse_macro_input!(item as ItemFn);
58+
let mut new_generics = Generics::default();
59+
let mut param_count = 0;
60+
61+
// Process function arguments
62+
let mut new_params = Vec::new();
63+
let args = input
64+
.sig
65+
.inputs
66+
.iter()
67+
.map(|arg| {
68+
match arg {
69+
FnArg::Receiver(receiver) => receiver.to_token_stream(),
70+
FnArg::Typed(PatType { pat, .. }) => {
71+
if let Pat::Ident(pat_ident) = &**pat {
72+
let generic_name = format!("T{}", param_count);
73+
let generic_ident =
74+
syn::Ident::new(&generic_name, proc_macro2::Span::call_site());
75+
76+
// Create type parameter with optional bounds
77+
let mut type_param = TypeParam::from(generic_ident.clone());
78+
if let Some((_, bound)) = bounds_args
79+
.bounds
80+
.iter()
81+
.find(|(name, _)| name == &generic_name)
82+
{
83+
type_param.bounds.extend(vec![bound.clone()]);
84+
}
85+
86+
new_params.push(GenericParam::Type(type_param));
87+
param_count += 1;
88+
quote! { #pat_ident: #generic_ident }
89+
} else {
90+
arg.to_token_stream()
91+
}
92+
}
93+
}
94+
})
95+
.collect::<Vec<_>>();
96+
97+
// Add R type parameter with optional bounds
98+
let generic_r = syn::Ident::new("R", proc_macro2::Span::call_site());
99+
let mut return_type_param = TypeParam::from(generic_r.clone());
100+
if let Some((_, bound)) = bounds_args.bounds.iter().find(|(name, _)| name == "R") {
101+
return_type_param.bounds.extend(vec![bound.clone()]);
102+
}
103+
new_params.push(GenericParam::Type(return_type_param));
104+
105+
// Add all generic parameters
106+
new_generics.params.extend(new_params);
107+
108+
let fn_name = &input.sig.ident;
109+
let byot_fn_name = syn::Ident::new(&format!("{}_byot", fn_name), fn_name.span());
110+
let vis = &input.vis;
111+
let block = &input.block;
112+
let attrs = &input.attrs;
113+
let asyncness = &input.sig.asyncness;
114+
115+
// Parse where clause if provided
116+
let where_clause = if let Some(where_str) = bounds_args.where_clause {
117+
match syn::parse_str::<WhereClause>(&format!("where {}", where_str.replace("\"", ""))) {
118+
Ok(where_clause) => quote! { #where_clause },
119+
Err(e) => return TokenStream::from(e.to_compile_error()),
120+
}
121+
} else {
122+
quote! {}
123+
};
124+
125+
// Generate return type based on stream flag
126+
let return_type = if bounds_args.stream {
127+
quote! { Result<::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<R, OpenAIError>> + Send>>, OpenAIError> }
128+
} else {
129+
quote! { Result<R, OpenAIError> }
130+
};
131+
132+
let expanded = quote! {
133+
#(#attrs)*
134+
#input
135+
136+
#(#attrs)*
137+
#vis #asyncness fn #byot_fn_name #new_generics (#(#args),*) -> #return_type #where_clause #block
138+
};
139+
140+
expanded.into()
141+
}

async-openai/Cargo.toml

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
[package]
22
name = "async-openai"
3-
version = "0.27.2"
3+
version = "0.28.0"
44
authors = ["Himanshu Neema"]
55
categories = ["api-bindings", "web-programming", "asynchronous"]
66
keywords = ["openai", "async", "openapi", "ai"]
77
description = "Rust library for OpenAI"
88
edition = "2021"
9-
rust-version = "1.75"
9+
rust-version = { workspace = true }
1010
license = "MIT"
1111
readme = "README.md"
1212
homepage = "https://github.com/64bit/async-openai"
@@ -23,8 +23,11 @@ native-tls = ["reqwest/native-tls"]
2323
# Remove dependency on OpenSSL
2424
native-tls-vendored = ["reqwest/native-tls-vendored"]
2525
realtime = ["dep:tokio-tungstenite"]
26+
# Bring your own types
27+
byot = []
2628

2729
[dependencies]
30+
async-openai-macros = { path = "../async-openai-macros", version = "0.1.0" }
2831
backoff = { version = "0.4.0", features = ["tokio"] }
2932
base64 = "0.22.1"
3033
futures = "0.3.31"
@@ -50,6 +53,11 @@ tokio-tungstenite = { version = "0.26.1", optional = true, default-features = fa
5053

5154
[dev-dependencies]
5255
tokio-test = "0.4.4"
56+
serde_json = "1.0"
57+
58+
[[test]]
59+
name = "bring-your-own-type"
60+
required-features = ["byot"]
5361

5462
[package.metadata.docs.rs]
5563
all-features = true

async-openai/README.md

+35-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
- [x] Organizations | Administration (partially implemented)
3838
- [x] Realtime (Beta) (partially implemented)
3939
- [x] Uploads
40+
- Bring your own custom types for Request or Response objects.
4041
- SSE streaming on available APIs
4142
- Requests (except SSE streaming) including form submissions are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits).
4243
- Ergonomic builder pattern for all request objects.
@@ -62,7 +63,7 @@ $Env:OPENAI_API_KEY='sk-...'
6263
## Realtime API
6364

6465
Only types for Realtime API are implemented, and can be enabled with feature flag `realtime`.
65-
These types may change if/when OpenAI releases official specs for them.
66+
These types were written before OpenAI released official specs.
6667

6768
## Image Generation Example
6869

@@ -108,6 +109,39 @@ async fn main() -> Result<(), Box<dyn Error>> {
108109
<sub>Scaled up for README, actual size 256x256</sub>
109110
</div>
110111

112+
## Bring Your Own Types
113+
114+
Enable methods whose input and outputs are generics with `byot` feature. It creates a new method with same name and `_byot` suffix.
115+
116+
For example, to use `serde_json::Value` as request and response type:
117+
```rust
118+
let response: Value = client
119+
.chat()
120+
.create_byot(json!({
121+
"messages": [
122+
{
123+
"role": "developer",
124+
"content": "You are a helpful assistant"
125+
},
126+
{
127+
"role": "user",
128+
"content": "What do you think about life?"
129+
}
130+
],
131+
"model": "gpt-4o",
132+
"store": false
133+
}))
134+
.await?;
135+
```
136+
137+
This can be useful in many scenarios:
138+
- To use this library with other OpenAI compatible APIs whose types don't exactly match OpenAI.
139+
- Extend existing types in this crate with new fields with `serde`.
140+
- To avoid verbose types.
141+
- To escape deserialization errors.
142+
143+
Visit [examples/bring-your-own-type](https://github.com/64bit/async-openai/tree/main/examples/bring-your-own-type) directory to learn more.
144+
111145
## Contributing
112146

113147
Thank you for taking the time to contribute and improve the project. I'd be happy to have you!

async-openai/src/assistant_files.rs

-66
This file was deleted.

async-openai/src/assistants.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
AssistantObject, CreateAssistantRequest, DeleteAssistantResponse, ListAssistantsResponse,
88
ModifyAssistantRequest,
99
},
10-
AssistantFiles, Client,
10+
Client,
1111
};
1212

1313
/// Build assistants that can call models and use tools to perform tasks.
@@ -22,12 +22,8 @@ impl<'c, C: Config> Assistants<'c, C> {
2222
Self { client }
2323
}
2424

25-
/// Assistant [AssistantFiles] API group
26-
pub fn files(&self, assistant_id: &str) -> AssistantFiles<C> {
27-
AssistantFiles::new(self.client, assistant_id)
28-
}
29-
3025
/// Create an assistant with a model and instructions.
26+
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
3127
pub async fn create(
3228
&self,
3329
request: CreateAssistantRequest,
@@ -36,13 +32,15 @@ impl<'c, C: Config> Assistants<'c, C> {
3632
}
3733

3834
/// Retrieves an assistant.
35+
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
3936
pub async fn retrieve(&self, assistant_id: &str) -> Result<AssistantObject, OpenAIError> {
4037
self.client
4138
.get(&format!("/assistants/{assistant_id}"))
4239
.await
4340
}
4441

4542
/// Modifies an assistant.
43+
#[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
4644
pub async fn update(
4745
&self,
4846
assistant_id: &str,
@@ -54,17 +52,19 @@ impl<'c, C: Config> Assistants<'c, C> {
5452
}
5553

5654
/// Delete an assistant.
55+
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
5756
pub async fn delete(&self, assistant_id: &str) -> Result<DeleteAssistantResponse, OpenAIError> {
5857
self.client
5958
.delete(&format!("/assistants/{assistant_id}"))
6059
.await
6160
}
6261

6362
/// Returns a list of assistants.
63+
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
6464
pub async fn list<Q>(&self, query: &Q) -> Result<ListAssistantsResponse, OpenAIError>
6565
where
6666
Q: Serialize + ?Sized,
6767
{
68-
self.client.get_with_query("/assistants", query).await
68+
self.client.get_with_query("/assistants", &query).await
6969
}
7070
}

0 commit comments

Comments
 (0)