77
88use heck:: ToKebabCase ;
99use heck:: { ToSnakeCase , ToUpperCamelCase } ;
10- use proc_macro2:: TokenStream ;
10+ use proc_macro2:: { Literal , TokenStream } ;
1111use quote:: format_ident;
1212use quote:: quote;
1313use std:: collections:: HashSet ;
1414use std:: path:: PathBuf ;
1515use tauri_bindgen_core:: { Generate , GeneratorBuilder , TypeInfo , TypeInfos } ;
1616use tauri_bindgen_gen_rust:: { print_generics, BorrowMode , FnSig , RustGenerator } ;
17- use wit_parser:: { Function , FunctionResult , Interface , Type , TypeDefKind } ;
17+ use wit_parser:: { Function , Interface , Type , TypeDefKind } ;
1818
1919#[ derive( Default , Debug , Clone ) ]
2020#[ cfg_attr( feature = "clap" , derive( clap:: Args ) ) ]
@@ -252,7 +252,7 @@ impl Host {
252252
253253 let functions = functions. map ( |func| {
254254 let sig = FnSig {
255- async_ : false ,
255+ async_ : self . opts . async_ ,
256256 unsafe_ : false ,
257257 private : true ,
258258 self_arg : Some ( quote ! ( & self ) ) ,
@@ -266,7 +266,13 @@ impl Host {
266266
267267 let sized = sized. then_some ( quote ! ( : Sized ) ) ;
268268
269+ let async_trait = self
270+ . opts
271+ . async_
272+ . then_some ( quote ! { #[ :: tauri_bindgen_host:: async_trait] } ) ;
273+
269274 quote ! {
275+ #async_trait
270276 pub trait #ident #sized {
271277 #( #additional_items) *
272278 #( #functions) *
@@ -304,114 +310,145 @@ impl Host {
304310 }
305311 }
306312
307- fn print_add_to_router < ' a > (
308- & self ,
309- mod_ident : & str ,
310- functions : impl Iterator < Item = & ' a Function > ,
311- methods : impl Iterator < Item = ( & ' a str , & ' a Function ) > ,
312- ) -> TokenStream {
313- let trait_ident = format_ident ! ( "{}" , mod_ident. to_upper_camel_case( ) ) ;
314-
315- let mod_name = mod_ident. to_snake_case ( ) ;
316-
317- let functions = functions. map ( |func| {
318- let func_name = func. id . to_snake_case ( ) ;
319- let func_ident = format_ident ! ( "{}" , func_name) ;
320-
321- let params = self . print_function_params ( & func. params , & BorrowMode :: Owned ) ;
322-
323- let param_idents = func
324- . params
325- . iter ( )
326- . map ( |( ident, _) | { format_ident ! ( "{}" , ident) } ) ;
327-
328- let result = match func. result . as_ref ( ) {
329- Some ( FunctionResult :: Anon ( ty) ) => {
330- let ty = self . print_ty ( ty, & BorrowMode :: Owned ) ;
331-
332- quote ! { #ty }
333- }
334- Some ( FunctionResult :: Named ( types) ) if types. len ( ) == 1 => {
335- let ( _, ty) = & types[ 0 ] ;
336- let ty = self . print_ty ( ty, & BorrowMode :: Owned ) ;
337-
338- quote ! { #ty }
339- }
340- Some ( FunctionResult :: Named ( types) ) => {
341- let types = types. iter ( ) . map ( |( _, ty) | self . print_ty ( ty, & BorrowMode :: Owned ) ) ;
313+ fn print_router_fn_definition ( & self , mod_name : & str , func : & Function ) -> TokenStream {
314+ let func_name = func. ident . to_snake_case ( ) ;
315+ let func_ident = format_ident ! ( "{}" , func_name) ;
342316
343- quote ! { ( #( #types) , * ) }
344- }
345- _ => quote ! { ( ) } ,
346- } ;
317+ let param_decl = match func. params . len ( ) {
318+ 0 => quote ! { ( ) } ,
319+ 1 => {
320+ let ty = & func. params . first ( ) . unwrap ( ) . 1 ;
321+ let ty = self . print_ty ( ty, & BorrowMode :: Owned ) ;
322+ quote ! { #ty }
323+ }
324+ _ => {
325+ let tys = func
326+ . params
327+ . iter ( )
328+ . map ( |( _, ty) | self . print_ty ( ty, & BorrowMode :: Owned ) ) ;
329+ quote ! { ( #( #tys) , * ) }
330+ }
331+ } ;
332+
333+ let param_acc = match func. params . len ( ) {
334+ 0 => quote ! { } ,
335+ 1 => quote ! { p } ,
336+ _ => {
337+ let ids = func. params . iter ( ) . enumerate ( ) . map ( |( i, _) | {
338+ let i = Literal :: usize_unsuffixed ( i) ;
339+ quote ! { p. #i }
340+ } ) ;
341+ quote ! { #( #ids) , * }
342+ }
343+ } ;
347344
345+ if self . opts . async_ {
346+ quote ! {
347+ let get_cx = :: std:: sync:: Arc :: clone( & wrapped_get_cx) ;
348+ router. define_async(
349+ #mod_name,
350+ #func_name,
351+ move |ctx: :: tauri_bindgen_host:: ipc_router_wip:: Caller <T >, p: #param_decl| {
352+ let get_cx = get_cx. clone( ) ;
353+ Box :: pin( async move {
354+ let ctx = get_cx( ctx. data( ) ) ;
355+ Ok ( ctx. #func_ident( #param_acc) . await )
356+ } )
357+ } ) ?;
358+ }
359+ } else {
348360 quote ! {
349361 let get_cx = :: std:: sync:: Arc :: clone( & wrapped_get_cx) ;
350- router. func_wrap (
362+ router. define (
351363 #mod_name,
352364 #func_name,
353- move |ctx: :: tauri_bindgen_host:: ipc_router_wip:: Caller <T >, #params| -> :: tauri_bindgen_host :: anyhow :: Result <#result> {
365+ move |ctx: :: tauri_bindgen_host:: ipc_router_wip:: Caller <T >, p : #param_decl| {
354366 let ctx = get_cx( ctx. data( ) ) ;
355367
356- Ok ( ctx. #func_ident( #( #param_idents ) , * ) )
368+ Ok ( ctx. #func_ident( #param_acc ) )
357369 } ,
358370 ) ?;
359371 }
360- } ) ;
361-
362- let methods = methods. map ( |( resource_name, method) | {
363- let func_name = method. id . to_snake_case ( ) ;
364- let func_ident = format_ident ! ( "{}" , func_name) ;
365-
366- let params = self . print_function_params ( & method. params , & BorrowMode :: Owned ) ;
367-
368- let param_idents = method
369- . params
370- . iter ( )
371- . map ( |( ident, _) | format_ident ! ( "{}" , ident) ) ;
372-
373- let result = match method. result . as_ref ( ) {
374- Some ( FunctionResult :: Anon ( ty) ) => {
375- let ty = self . print_ty ( ty, & BorrowMode :: Owned ) ;
376-
377- quote ! { #ty }
378- }
379- Some ( FunctionResult :: Named ( types) ) if types. len ( ) == 1 => {
380- let ( _, ty) = & types[ 0 ] ;
381- let ty = self . print_ty ( ty, & BorrowMode :: Owned ) ;
372+ }
373+ }
382374
383- quote ! { #ty }
384- }
385- Some ( FunctionResult :: Named ( types) ) => {
386- let types = types
387- . iter ( )
388- . map ( |( _, ty) | self . print_ty ( ty, & BorrowMode :: Owned ) ) ;
375+ fn print_router_method_definition (
376+ & self ,
377+ mod_name : & str ,
378+ resource_name : & str ,
379+ method : & Function ,
380+ ) -> TokenStream {
381+ let func_name = method. ident . to_snake_case ( ) ;
382+ let func_ident = format_ident ! ( "{}" , func_name) ;
389383
390- quote ! { ( #( #types) , * ) }
391- }
392- _ => quote ! { ( ) } ,
393- } ;
384+ let param_decl = method
385+ . params
386+ . iter ( )
387+ . map ( |( _, ty) | self . print_ty ( ty, & BorrowMode :: Owned ) ) ;
388+
389+ let param_acc = match method. params . len ( ) {
390+ 0 => quote ! { } ,
391+ 1 => quote ! { p. 1 } ,
392+ _ => {
393+ let ids = method. params . iter ( ) . enumerate ( ) . map ( |( i, _) | {
394+ let i = Literal :: usize_unsuffixed ( i + 1 ) ;
395+ quote ! { p. #i }
396+ } ) ;
397+ quote ! { #( #ids) , * }
398+ }
399+ } ;
394400
395- let mod_name = format ! ( "{mod_name}::resource::{resource_name}" ) ;
396- let get_r_ident = format_ident ! ( "get_{}" , resource_name. to_snake_case( ) ) ;
401+ let mod_name = format ! ( "{mod_name}::resource::{resource_name}" ) ;
402+ let get_r_ident = format_ident ! ( "get_{}" , resource_name. to_snake_case( ) ) ;
397403
404+ if self . opts . async_ {
398405 quote ! {
399406 let get_cx = :: std:: sync:: Arc :: clone( & wrapped_get_cx) ;
400- router. func_wrap(
407+ router. define_async(
408+ #mod_name,
409+ #func_name,
410+ move |ctx: :: tauri_bindgen_host:: ipc_router_wip:: Caller <T >, p: ( :: tauri_bindgen_host:: ResourceId , #( #param_decl) , * ) | {
411+ let get_cx = get_cx. clone( ) ;
412+ Box :: pin( async move {
413+ let ctx = get_cx( ctx. data( ) ) ;
414+ let r = ctx. #get_r_ident( p. 0 ) ?;
415+ Ok ( r. #func_ident( #param_acc) . await )
416+ } )
417+ } ) ?;
418+ }
419+ } else {
420+ quote ! {
421+ let get_cx = :: std:: sync:: Arc :: clone( & wrapped_get_cx) ;
422+ router. define(
401423 #mod_name,
402424 #func_name,
403425 move |
404426 ctx: :: tauri_bindgen_host:: ipc_router_wip:: Caller <T >,
405- this_rid: :: tauri_bindgen_host:: ResourceId ,
406- #params
407- | -> :: tauri_bindgen_host:: anyhow:: Result <#result> {
427+ p: ( :: tauri_bindgen_host:: ResourceId , #( #param_decl) , * )
428+ | {
408429 let ctx = get_cx( ctx. data( ) ) ;
409- let r = ctx. #get_r_ident( this_rid) ?;
410-
411- Ok ( r. #func_ident( #( #param_idents) , * ) )
430+ let r = ctx. #get_r_ident( p. 0 ) ?;
431+ Ok ( r. #func_ident( #param_acc) )
412432 } ,
413433 ) ?;
414434 }
435+ }
436+ }
437+
438+ fn print_add_to_router < ' a > (
439+ & self ,
440+ mod_ident : & str ,
441+ functions : impl Iterator < Item = & ' a Function > ,
442+ methods : impl Iterator < Item = ( & ' a str , & ' a Function ) > ,
443+ ) -> TokenStream {
444+ let trait_ident = format_ident ! ( "{}" , mod_ident. to_upper_camel_case( ) ) ;
445+
446+ let mod_name = mod_ident. to_snake_case ( ) ;
447+
448+ let functions = functions. map ( |func| self . print_router_fn_definition ( & mod_name, func) ) ;
449+
450+ let methods = methods. map ( |( resource_name, method) | {
451+ self . print_router_method_definition ( & mod_name, resource_name, method)
415452 } ) ;
416453
417454 quote ! {
@@ -420,6 +457,7 @@ impl Host {
420457 get_cx: impl Fn ( & T ) -> & U + Send + Sync + ' static ,
421458 ) -> Result <( ) , :: tauri_bindgen_host:: ipc_router_wip:: Error >
422459 where
460+ T : Send + Sync + ' static ,
423461 U : #trait_ident + Send + Sync + ' static ,
424462 {
425463 let wrapped_get_cx = :: std:: sync:: Arc :: new( get_cx) ;
0 commit comments