diff --git a/README.md b/README.md index 1e01ccd..23b2c76 100644 --- a/README.md +++ b/README.md @@ -31,14 +31,11 @@ Then you're ready to develop your Restate service using Rust: ```rust use restate_sdk::prelude::*; -#[restate_sdk::service] -trait Greeter { - async fn greet(name: String) -> HandlerResult; -} - -struct GreeterImpl; +pub(crate) struct Greeter; -impl Greeter for GreeterImpl { +#[restate_sdk::service(vis = "pub(crate)", name = "Greetings")] +impl Greeter { + #[handler] async fn greet(&self, _: Context<'_>, name: String) -> HandlerResult { Ok(format!("Greetings {name}")) } @@ -50,7 +47,7 @@ async fn main() { // tracing_subscriber::fmt::init(); HttpServer::new( Endpoint::builder() - .with_service(GreeterImpl.serve()) + .with_service(Greeter.serve()) .build(), ) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) @@ -75,7 +72,7 @@ async fn test_container() { .with_max_level(tracing::Level::INFO) // Set the maximum log level .init(); - let endpoint = Endpoint::builder().bind(MyServiceImpl.serve()).build(); + let endpoint = Endpoint::builder().bind(MyService.serve()).build(); // simple test container intialization with default configuration //let test_container = TestContainer::default().start(endpoint).await.unwrap(); diff --git a/examples/counter.rs b/examples/counter.rs index 8f85bc8..f41d8c0 100644 --- a/examples/counter.rs +++ b/examples/counter.rs @@ -1,23 +1,17 @@ use restate_sdk::prelude::*; -#[restate_sdk::object] -trait Counter { - #[shared] - async fn get() -> Result; - async fn add(val: u64) -> Result; - async fn increment() -> Result; - async fn reset() -> Result<(), TerminalError>; -} - -struct CounterImpl; - const COUNT: &str = "count"; -impl Counter for CounterImpl { +struct Counter; + +#[restate_sdk::object] +impl Counter { + #[handler(shared)] async fn get(&self, ctx: SharedObjectContext<'_>) -> Result { Ok(ctx.get::(COUNT).await?.unwrap_or(0)) } + #[handler] async fn add(&self, ctx: ObjectContext<'_>, val: u64) -> Result { let current = ctx.get::(COUNT).await?.unwrap_or(0); let new = current + val; @@ -25,10 +19,12 @@ impl Counter for CounterImpl { Ok(new) } + #[handler] async fn increment(&self, ctx: ObjectContext<'_>) -> Result { self.add(ctx, 1).await } + #[handler] async fn reset(&self, ctx: ObjectContext<'_>) -> Result<(), TerminalError> { ctx.clear(COUNT); Ok(()) @@ -38,7 +34,7 @@ impl Counter for CounterImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HttpServer::new(Endpoint::builder().bind(CounterImpl.serve()).build()) + HttpServer::new(Endpoint::builder().bind(Counter.serve()).build()) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) .await; } diff --git a/examples/cron.rs b/examples/cron.rs index eef7107..81badc4 100644 --- a/examples/cron.rs +++ b/examples/cron.rs @@ -13,21 +13,13 @@ use std::time::Duration; /// ```shell /// $ curl -v http://localhost:8080/PeriodicTask/my-periodic-task/start /// ``` -#[restate_sdk::object] -trait PeriodicTask { - /// Schedules the periodic task to start - async fn start() -> Result<(), TerminalError>; - /// Stops the periodic task - async fn stop() -> Result<(), TerminalError>; - /// Business logic of the periodic task - async fn run() -> Result<(), TerminalError>; -} - -struct PeriodicTaskImpl; +struct PeriodicTask; const ACTIVE: &str = "active"; -impl PeriodicTask for PeriodicTaskImpl { +#[restate_sdk::object] +impl PeriodicTask { + #[handler] async fn start(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> { if context .get::(ACTIVE) @@ -39,7 +31,7 @@ impl PeriodicTask for PeriodicTaskImpl { } // Schedule the periodic task - PeriodicTaskImpl::schedule_next(&context); + PeriodicTask::schedule_next(&context); // Mark the periodic task as active context.set(ACTIVE, true); @@ -47,6 +39,7 @@ impl PeriodicTask for PeriodicTaskImpl { Ok(()) } + #[handler] async fn stop(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> { // Remove the active flag context.clear(ACTIVE); @@ -54,6 +47,7 @@ impl PeriodicTask for PeriodicTaskImpl { Ok(()) } + #[handler] async fn run(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> { if context.get::(ACTIVE).await?.is_none() { // Task is inactive, do nothing @@ -64,13 +58,13 @@ impl PeriodicTask for PeriodicTaskImpl { println!("Triggered the periodic task!"); // Schedule the periodic task - PeriodicTaskImpl::schedule_next(&context); + PeriodicTask::schedule_next(&context); Ok(()) } } -impl PeriodicTaskImpl { +impl PeriodicTask { fn schedule_next(context: &ObjectContext<'_>) { // To schedule, create a client to the callee handler (in this case, we're calling ourselves) context @@ -84,7 +78,7 @@ impl PeriodicTaskImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HttpServer::new(Endpoint::builder().bind(PeriodicTaskImpl.serve()).build()) + HttpServer::new(Endpoint::builder().bind(PeriodicTask.serve()).build()) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) .await; } diff --git a/examples/failures.rs b/examples/failures.rs index 2465544..5545d94 100644 --- a/examples/failures.rs +++ b/examples/failures.rs @@ -1,19 +1,15 @@ use rand::RngCore; use restate_sdk::prelude::*; -#[restate_sdk::service] -trait FailureExample { - #[name = "doRun"] - async fn do_run() -> Result<(), TerminalError>; -} - -struct FailureExampleImpl; - #[derive(Debug, thiserror::Error)] #[error("I'm very bad, retry me")] struct MyError; -impl FailureExample for FailureExampleImpl { +struct FailureExample; + +#[restate_sdk::service] +impl FailureExample { + #[handler(name = "doRun")] async fn do_run(&self, context: Context<'_>) -> Result<(), TerminalError> { context .run::<_, _, ()>(|| async move { @@ -32,7 +28,7 @@ impl FailureExample for FailureExampleImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HttpServer::new(Endpoint::builder().bind(FailureExampleImpl.serve()).build()) + HttpServer::new(Endpoint::builder().bind(FailureExample.serve()).build()) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) .await; } diff --git a/examples/greeter.rs b/examples/greeter.rs index b597dbe..8d537ac 100644 --- a/examples/greeter.rs +++ b/examples/greeter.rs @@ -1,15 +1,12 @@ use restate_sdk::prelude::*; use std::convert::Infallible; -#[restate_sdk::service] -trait Greeter { - async fn greet(name: String) -> Result; -} +struct Greeter; -struct GreeterImpl; - -impl Greeter for GreeterImpl { - async fn greet(&self, _: Context<'_>, name: String) -> Result { +#[restate_sdk::service] +impl Greeter { + #[handler] + async fn greet(&self, _ctx: Context<'_>, name: String) -> Result { Ok(format!("Greetings {name}")) } } @@ -17,7 +14,7 @@ impl Greeter for GreeterImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HttpServer::new(Endpoint::builder().bind(GreeterImpl.serve()).build()) + HttpServer::new(Endpoint::builder().bind(Greeter.serve()).build()) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) .await; } diff --git a/examples/run.rs b/examples/run.rs index a9b5d16..0accdda 100644 --- a/examples/run.rs +++ b/examples/run.rs @@ -1,14 +1,11 @@ use restate_sdk::prelude::*; use std::collections::HashMap; -#[restate_sdk::service] -trait RunExample { - async fn do_run() -> Result>, HandlerError>; -} +struct RunExample(reqwest::Client); -struct RunExampleImpl(reqwest::Client); - -impl RunExample for RunExampleImpl { +#[restate_sdk::service] +impl RunExample { + #[handler] async fn do_run( &self, context: Context<'_>, @@ -39,7 +36,7 @@ async fn main() { tracing_subscriber::fmt::init(); HttpServer::new( Endpoint::builder() - .bind(RunExampleImpl(reqwest::Client::new()).serve()) + .bind(RunExample(reqwest::Client::new()).serve()) .build(), ) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) diff --git a/examples/services/my_service.rs b/examples/services/my_service.rs index 87f2518..1f2231d 100644 --- a/examples/services/my_service.rs +++ b/examples/services/my_service.rs @@ -1,13 +1,10 @@ use restate_sdk::prelude::*; -#[restate_sdk::service] -pub trait MyService { - async fn my_handler(greeting: String) -> Result; -} - -pub struct MyServiceImpl; +pub struct MyService; -impl MyService for MyServiceImpl { +#[restate_sdk::service(vis = "pub(crate)")] +impl MyService { + #[handler] async fn my_handler(&self, _ctx: Context<'_>, greeting: String) -> Result { Ok(format!("{greeting}!")) } @@ -16,7 +13,7 @@ impl MyService for MyServiceImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HttpServer::new(Endpoint::builder().bind(MyServiceImpl.serve()).build()) + HttpServer::new(Endpoint::builder().bind(MyService.serve()).build()) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) .await; } diff --git a/examples/services/my_virtual_object.rs b/examples/services/my_virtual_object.rs index 42813e8..635dc4c 100644 --- a/examples/services/my_virtual_object.rs +++ b/examples/services/my_virtual_object.rs @@ -1,15 +1,10 @@ use restate_sdk::prelude::*; -#[restate_sdk::object] -pub trait MyVirtualObject { - async fn my_handler(name: String) -> Result; - #[shared] - async fn my_concurrent_handler(name: String) -> Result; -} - -pub struct MyVirtualObjectImpl; +pub struct MyVirtualObject; -impl MyVirtualObject for MyVirtualObjectImpl { +#[restate_sdk::object(vis = "pub(crate)")] +impl MyVirtualObject { + #[handler] async fn my_handler( &self, ctx: ObjectContext<'_>, @@ -17,6 +12,8 @@ impl MyVirtualObject for MyVirtualObjectImpl { ) -> Result { Ok(format!("Greetings {} {}", greeting, ctx.key())) } + + #[handler(shared)] async fn my_concurrent_handler( &self, ctx: SharedObjectContext<'_>, @@ -31,7 +28,7 @@ async fn main() { tracing_subscriber::fmt::init(); HttpServer::new( Endpoint::builder() - .bind(MyVirtualObjectImpl.serve()) + .bind(MyVirtualObject.serve()) .build(), ) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) diff --git a/examples/services/my_workflow.rs b/examples/services/my_workflow.rs index c1af22d..8cf5f65 100644 --- a/examples/services/my_workflow.rs +++ b/examples/services/my_workflow.rs @@ -1,20 +1,17 @@ use restate_sdk::prelude::*; -#[restate_sdk::workflow] -pub trait MyWorkflow { - async fn run(req: String) -> Result; - #[shared] - async fn interact_with_workflow() -> Result<(), HandlerError>; -} - -pub struct MyWorkflowImpl; +pub struct MyWorkflow; -impl MyWorkflow for MyWorkflowImpl { +#[restate_sdk::workflow(vis = "pub(crate)")] +impl MyWorkflow { + #[handler] async fn run(&self, _ctx: WorkflowContext<'_>, _req: String) -> Result { // implement workflow logic here Ok(String::from("success")) } + + #[handler(shared)] async fn interact_with_workflow( &self, _ctx: SharedWorkflowContext<'_>, @@ -29,7 +26,7 @@ impl MyWorkflow for MyWorkflowImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HttpServer::new(Endpoint::builder().bind(MyWorkflowImpl.serve()).build()) + HttpServer::new(Endpoint::builder().bind(MyWorkflow.serve()).build()) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) .await; } diff --git a/examples/tracing.rs b/examples/tracing.rs index 19f6995..dec3a9a 100644 --- a/examples/tracing.rs +++ b/examples/tracing.rs @@ -3,14 +3,11 @@ use std::time::Duration; use tracing::info; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer}; -#[restate_sdk::service] -trait Greeter { - async fn greet(name: String) -> Result; -} +struct Greeter; -struct GreeterImpl; - -impl Greeter for GreeterImpl { +#[restate_sdk::service] +impl Greeter { + #[handler] async fn greet(&self, ctx: Context<'_>, name: String) -> Result { info!("Before sleep"); ctx.sleep(Duration::from_secs(61)).await?; // More than suspension timeout to trigger replay @@ -31,7 +28,7 @@ async fn main() { .with_filter(replay_filter), ) .init(); - HttpServer::new(Endpoint::builder().bind(GreeterImpl.serve()).build()) + HttpServer::new(Endpoint::builder().bind(Greeter.serve()).build()) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) .await; } diff --git a/macros/src/ast.rs b/macros/src/ast.rs index 7a9aadf..1916ef9 100644 --- a/macros/src/ast.rs +++ b/macros/src/ast.rs @@ -11,27 +11,14 @@ // Some parts copied from https://github.com/dtolnay/thiserror/blob/39aaeb00ff270a49e3c254d7b38b10e934d3c7a5/impl/src/ast.rs // License Apache-2.0 or MIT -use syn::ext::IdentExt; use syn::parse::{Parse, ParseStream}; use syn::spanned::Spanned; -use syn::token::Comma; use syn::{ - braced, parenthesized, parse_quote, Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, - Ident, Lit, Pat, PatType, Path, PathArguments, Result, ReturnType, Token, Type, Visibility, + parse_quote, Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, Ident, ImplItem, + ImplItemFn, ItemImpl, Lit, Meta, Pat, PatType, PathArguments, Result, ReturnType, Type, + Visibility, }; -/// Accumulates multiple errors into a result. -/// Only use this for recoverable errors, i.e. non-parse errors. Fatal errors should early exit to -/// avoid further complications. -macro_rules! extend_errors { - ($errors: ident, $e: expr) => { - match $errors { - Ok(_) => $errors = Err($e), - Err(ref mut errors) => errors.extend($e), - } - }; -} - #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum ServiceType { Service, @@ -63,207 +50,334 @@ impl Parse for Workflow { } } -pub(crate) struct ServiceInner { - pub(crate) attrs: Vec, - pub(crate) restate_name: String, +pub(crate) struct ValidArgs { pub(crate) vis: Visibility, - pub(crate) ident: Ident, - pub(crate) handlers: Vec, + pub(crate) restate_name: Option, } -impl ServiceInner { - fn parse(service_type: ServiceType, input: ParseStream) -> Result { - let parsed_attrs = input.call(Attribute::parse_outer)?; - let vis = input.parse()?; - input.parse::()?; - let ident: Ident = input.parse()?; - let content; - braced!(content in input); - let mut rpcs = Vec::::new(); - while !content.is_empty() { - let h: Handler = content.parse()?; - - if h.is_shared && service_type == ServiceType::Service { - return Err(Error::new( - h.ident.span(), - "Service handlers cannot be annotated with #[shared]", - )); - } +impl Parse for ValidArgs { + fn parse(input: ParseStream) -> Result { + let mut vis = None; + let mut restate_name = None; - rpcs.push(h); - } - let mut ident_errors = Ok(()); - for rpc in &rpcs { - if rpc.ident == "new" { - extend_errors!( - ident_errors, - Error::new( - rpc.ident.span(), - format!( - "method name conflicts with generated fn `{}Client::new`", - ident.unraw() - ) - ) - ); - } - if rpc.ident == "serve" { - extend_errors!( - ident_errors, - Error::new( - rpc.ident.span(), - format!("method name conflicts with generated fn `{ident}::serve`") - ) - ); - } - } - ident_errors?; + let punctuated = + syn::punctuated::Punctuated::::parse_terminated(input)?; - let mut attrs = vec![]; - let mut restate_name = ident.to_string(); - for attr in parsed_attrs { - if let Some(name) = read_literal_attribute_name(&attr)? { - restate_name = name; - } else { - // Just propagate - attrs.push(attr); + for meta in punctuated { + match meta { + Meta::NameValue(name_value) if name_value.path.is_ident("vis") => { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = &name_value.value + { + let vis_str = lit_str.value(); + vis = Some(syn::parse_str::(&vis_str).map_err(|e| { + Error::new( + name_value.value.span(), + format!( + "Invalid visibility modifier '{}'. Expected \"pub\", \"pub(crate)\", etc.: {}", + vis_str, e + ), + ) + })?); + } else { + return Err(Error::new( + name_value.value.span(), + "Expected a string literal for 'vis' (e.g., vis = \"pub\", vis = \"pub(crate)\")", + )); + } + } + Meta::NameValue(name_value) if name_value.path.is_ident("name") => { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = &name_value.value + { + restate_name = Some(lit_str.value()); + } else { + return Err(Error::new( + name_value.span(), + "Expected a string literal for 'name'", + )); + } + } + bad_meta => { + return Err(Error::new( + bad_meta.span(), + "Invalid attribute format. Expected #[service(vis = pub(crate), name = \"...\")]", + )); + } } } Ok(Self { - attrs, + vis: vis.unwrap_or(Visibility::Inherited), restate_name, - vis, - ident, - handlers: rpcs, }) } } -pub(crate) struct Handler { +pub(crate) struct ServiceInner { pub(crate) attrs: Vec, + pub(crate) restate_name: String, + pub(crate) ident: Ident, + pub(crate) vis: Visibility, + pub(crate) impl_block: ItemImpl, + pub(crate) handlers: Vec, +} + +pub(crate) struct Handler { pub(crate) is_shared: bool, pub(crate) restate_name: String, pub(crate) ident: Ident, pub(crate) arg: Option, pub(crate) output_ok: Type, + #[allow(dead_code)] pub(crate) output_err: Type, } -impl Parse for Handler { - fn parse(input: ParseStream) -> Result { - let parsed_attrs = input.call(Attribute::parse_outer)?; - - input.parse::()?; - input.parse::()?; - let ident: Ident = input.parse()?; - - // Parse arguments - let content; - parenthesized!(content in input); - let mut args = Vec::new(); - let mut errors = Ok(()); - for arg in content.parse_terminated(FnArg::parse, Comma)? { - match arg { - FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => { - args.push(captured); - } - FnArg::Typed(captured) => { - extend_errors!( - errors, - Error::new(captured.pat.span(), "patterns aren't allowed in RPC args") - ); - } - FnArg::Receiver(_) => { - extend_errors!( - errors, - Error::new(arg.span(), "method args cannot start with self") - ); - } +impl ServiceInner { + fn parse(service_type: ServiceType, input: ParseStream) -> Result { + let mut impl_block = input.parse::()?; + let ident = match impl_block.self_ty.as_ref() { + Type::Path(path) => path.path.segments[0].ident.clone(), + bad_path => { + return Err(Error::new(bad_path.span(), "Only on impl blocks")); } - } - if args.len() > 1 { - extend_errors!( - errors, - Error::new(content.span(), "Only one input argument is supported") - ); - } - errors?; + }; - // Parse return type - let return_type: ReturnType = input.parse()?; - input.parse::()?; + let mut rpcs = Vec::new(); + for item in impl_block.items.iter_mut() { + match item { + ImplItem::Const(_) => {} + ImplItem::Fn(handler) => { + let mut is_handler = false; + let mut is_shared = false; + let mut restate_name = None; - let (ok_ty, err_ty) = match &return_type { - ReturnType::Default => return Err(Error::new( - return_type.span(), - "The return type cannot be empty, only Result or restate_sdk::prelude::HandlerResult is supported as return type", - )), - ReturnType::Type(_, ty) => { - if let Some((ok_ty, err_ty)) = extract_handler_result_parameter(ty) { - (ok_ty, err_ty) - } else { - return Err(Error::new( - return_type.span(), - "Only Result or restate_sdk::prelude::HandlerResult is supported as return type", - )); - } - } - }; + let mut attrs = Vec::with_capacity(handler.attrs.len()); + for attr in &handler.attrs { + if attr.path().is_ident("handler") { + if is_handler { + return Err(Error::new( + attr.span(), + "Multiple `#[handler]` attributes found.", + )); + } + if handler.sig.asyncness.is_none() { + return Err(Error::new( + handler.sig.fn_token.span(), + "expected async, handlers are async fn", + )); + } + is_handler = true; + (is_shared, restate_name) = + extract_handler_attributes(service_type, attr)?; + } else { + attrs.push(attr.clone()); + } + } - // Process attributes - let mut is_shared = false; - let mut restate_name = ident.to_string(); - let mut attrs = vec![]; - for attr in parsed_attrs { - if is_shared_attr(&attr) { - is_shared = true; - } else if let Some(name) = read_literal_attribute_name(&attr)? { - restate_name = name; - } else { - // Just propagate - attrs.push(attr); + if is_handler { + let handler_arg = + validate_handler_arguments(service_type, is_shared, handler)?; + + let return_type: ReturnType = handler.sig.output.clone(); + let (output_ok, output_err) = match &return_type { + ReturnType::Default => { + return Err(Error::new( + return_type.span(), + "The return type cannot be empty, only Result or restate_sdk::prelude::HandlerResult is supported as return type", + )); + } + ReturnType::Type(_, ty) => { + if let Some((ok_ty, err_ty)) = extract_handler_result_parameter(ty) + { + (ok_ty, err_ty) + } else { + return Err(Error::new( + return_type.span(), + "Only Result or restate_sdk::prelude::HandlerResult is supported as return type", + )); + } + } + }; + + handler.attrs = attrs; + + rpcs.push(Handler { + is_shared, + ident: handler.sig.ident.clone(), + restate_name: restate_name.unwrap_or(handler.sig.ident.to_string()), + arg: handler_arg, + output_ok, + output_err, + }); + } + } + bad_impl_item => { + return Err(Error::new(bad_impl_item.span(), "Only on consts and fns")); + } } } Ok(Self { - attrs, - is_shared, - restate_name, + attrs: impl_block.attrs.clone(), + restate_name: "".to_string(), ident, - arg: args.pop(), - output_ok: ok_ty, - output_err: err_ty, + vis: Visibility::Inherited, + impl_block, + handlers: rpcs, }) } } -fn is_shared_attr(attr: &Attribute) -> bool { - attr.meta - .require_path_only() - .and_then(Path::require_ident) - .is_ok_and(|i| i == "shared") +fn extract_handler_attributes( + service_type: ServiceType, + attr: &Attribute, +) -> Result<(bool, Option)> { + let mut is_shared = false; + let mut restate_name = None; + + match &attr.meta { + Meta::Path(_) => {} + Meta::List(meta_list) => { + let mut seen_shared = false; + let mut seen_name = false; + meta_list.parse_nested_meta(|meta| { + if meta.path.is_ident("shared") { + if seen_shared { + return Err(Error::new(meta.path.span(), "Duplicate `shared`")); + } + if service_type == ServiceType::Service { + return Err(Error::new( + meta.path.span(), + "Service handlers cannot be annotated with #[handler(shared)]", + )); + } + is_shared = true; + seen_shared = true; + } else if meta.path.is_ident("name") { + if seen_name { + return Err(Error::new(meta.path.span(), "Duplicate `name`")); + } + let lit: Lit = meta.value()?.parse()?; + if let Lit::Str(lit_str) = lit { + seen_name = true; + restate_name = Some(lit_str.value()); + } else { + return Err(Error::new( + lit.span(), + "Expected `name` to be a string literal", + )); + } + } else { + return Err(Error::new( + meta.path.span(), + "Invalid attribute inside #[handler]", + )); + } + Ok(()) + })?; + } + Meta::NameValue(_) => { + return Err(Error::new( + attr.meta.span(), + "Invalid attribute format for #[handler]", + )); + } + } + Ok((is_shared, restate_name)) } -fn read_literal_attribute_name(attr: &Attribute) -> Result> { - attr.meta - .require_name_value() - .ok() - .filter(|val| val.path.require_ident().is_ok_and(|i| i == "name")) - .map(|val| { - if let Expr::Lit(ExprLit { - lit: Lit::Str(ref literal), - .. - }) = &val.value - { - Ok(literal.value()) +fn validate_handler_arguments( + service_type: ServiceType, + is_shared: bool, + handler: &ImplItemFn, +) -> Result> { + let mut args_iter = handler.sig.inputs.iter(); + + match args_iter.next() { + Some(FnArg::Receiver(_)) => {} + Some(arg) => { + return Err(Error::new( + arg.span(), + "handler should have a `self` argument", + )); + } + None => { + return Err(Error::new( + handler.sig.ident.span(), + "Invalid handler arguments. It should be like (`self`, `ctx`, optional arg)", + )); + } + }; + + let valid_ctx: Ident = match (&service_type, is_shared) { + (ServiceType::Service, _) => parse_quote! { Context }, + (ServiceType::Object, true) => parse_quote! { SharedObjectContext }, + (ServiceType::Object, false) => parse_quote! { ObjectContext }, + (ServiceType::Workflow, true) => parse_quote! { SharedWorkflowContext }, + (ServiceType::Workflow, false) => parse_quote! { WorkflowContext }, + }; + + // TODO: allow the user to have unused context like _:Context in the handler + match args_iter.next() { + Some(arg @ FnArg::Typed(typed_arg)) if matches!(&*typed_arg.pat, Pat::Ident(_)) => { + if let Type::Path(type_path) = &*typed_arg.ty { + let ctx_ident = &type_path.path.segments.last().unwrap().ident; + + if ctx_ident != &valid_ctx { + let service_desc = match service_type { + ServiceType::Service => "service", + ServiceType::Object => { + if is_shared { + "shared object" + } else { + "object" + } + } + ServiceType::Workflow => { + if is_shared { + "shared workflow" + } else { + "workflow" + } + } + }; + + return Err(Error::new( + ctx_ident.span(), + format!( + "Expects `{}` type for this `{}`, but `{}` was provided.", + valid_ctx, service_desc, ctx_ident + ), + )); + } } else { - Err(Error::new( - val.span(), - "Only string literal is allowed for the 'name' attribute", - )) + return Err(Error::new( + arg.span(), + "Second argument must be one of the allowed context types", + )); } - }) - .transpose() + } + _ => { + return Err(Error::new( + handler.sig.ident.span(), + "Invalid handler arguments. It should be like (`self`, `ctx`, optional arg)", + )); + } + }; + + match args_iter.next() { + Some(FnArg::Typed(type_arg)) => Ok(Some(type_arg.clone())), + Some(FnArg::Receiver(arg)) => Err(Error::new( + arg.span(), + "Invalid handler arguments. It should be like (`self`, `ctx`, arg)", + )), + None => Ok(None), + } } fn extract_handler_result_parameter(ty: &Type) -> Option<(Type, Type)> { @@ -273,6 +387,7 @@ fn extract_handler_result_parameter(ty: &Type) -> Option<(Type, Type)> { }; let last = path.segments.last().unwrap(); + let is_result = last.ident == "Result"; let is_handler_result = last.ident == "HandlerResult"; if !is_result && !is_handler_result { diff --git a/macros/src/gen.rs b/macros/src/gen.rs index a882b4d..ad7cd07 100644 --- a/macros/src/gen.rs +++ b/macros/src/gen.rs @@ -2,7 +2,7 @@ use crate::ast::{Handler, Object, Service, ServiceInner, ServiceType, Workflow}; use proc_macro2::TokenStream as TokenStream2; use proc_macro2::{Ident, Literal}; use quote::{format_ident, quote, ToTokens}; -use syn::{Attribute, PatType, Visibility}; +use syn::{Attribute, ItemImpl, PatType, Visibility}; pub(crate) struct ServiceGenerator<'a> { pub(crate) service_ty: ServiceType, @@ -10,8 +10,9 @@ pub(crate) struct ServiceGenerator<'a> { pub(crate) service_ident: &'a Ident, pub(crate) client_ident: Ident, pub(crate) serve_ident: Ident, - pub(crate) vis: &'a Visibility, pub(crate) attrs: &'a [Attribute], + pub(crate) vis: &'a Visibility, + pub(crate) impl_block: &'a ItemImpl, pub(crate) handlers: &'a [Handler], } @@ -23,8 +24,9 @@ impl<'a> ServiceGenerator<'a> { service_ident: &s.ident, client_ident: format_ident!("{}Client", s.ident), serve_ident: format_ident!("Serve{}", s.ident), - vis: &s.vis, attrs: &s.attrs, + vis: &s.vis, + impl_block: &s.impl_block, handlers: &s.handlers, } } @@ -44,42 +46,20 @@ impl<'a> ServiceGenerator<'a> { fn trait_service(&self) -> TokenStream2 { let Self { attrs, - handlers, - vis, - service_ident, - service_ty, serve_ident, + service_ident, + impl_block, + vis, .. } = self; - let handler_fns = handlers - .iter() - .map( - |Handler { attrs, ident, arg, is_shared, output_ok, output_err, .. }| { - let args = arg.iter(); - - let ctx = match (&service_ty, is_shared) { - (ServiceType::Service, _) => quote! { ::restate_sdk::prelude::Context }, - (ServiceType::Object, true) => quote! { ::restate_sdk::prelude::SharedObjectContext }, - (ServiceType::Object, false) => quote! { ::restate_sdk::prelude::ObjectContext }, - (ServiceType::Workflow, true) => quote! { ::restate_sdk::prelude::SharedWorkflowContext }, - (ServiceType::Workflow, false) => quote! { ::restate_sdk::prelude::WorkflowContext }, - }; - - quote! { - #( #attrs )* - fn #ident(&self, context: #ctx, #( #args ),*) -> impl std::future::Future> + ::core::marker::Send; - } - }, - ); - quote! { - #( #attrs )* - #vis trait #service_ident: ::core::marker::Sized { - #( #handler_fns )* + #impl_block + #( #attrs )* + impl #service_ident { /// Returns a serving function to use with [::restate_sdk::endpoint::Builder::with_service]. - fn serve(self) -> #serve_ident { + #vis fn serve(self) -> #serve_ident { #serve_ident { service: ::std::sync::Arc::new(self) } } } @@ -88,8 +68,8 @@ impl<'a> ServiceGenerator<'a> { fn struct_serve(&self) -> TokenStream2 { let &Self { - vis, ref serve_ident, + vis, .. } = self; @@ -116,12 +96,12 @@ impl<'a> ServiceGenerator<'a> { let get_input_and_call = if handler.arg.is_some() { quote! { let (input, metadata) = ctx.input().await; - let fut = S::#handler_ident(&service_clone, (&ctx, metadata).into(), input); + let fut = #service_ident::#handler_ident(&service_clone, (&ctx, metadata).into(), input); } } else { quote! { let (_, metadata) = ctx.input::<()>().await; - let fut = S::#handler_ident(&service_clone, (&ctx, metadata).into()); + let fut = #service_ident::#handler_ident(&service_clone, (&ctx, metadata).into()); } }; @@ -139,8 +119,7 @@ impl<'a> ServiceGenerator<'a> { }); quote! { - impl ::restate_sdk::service::Service for #serve_ident - where S: #service_ident + Send + Sync + 'static, + impl ::restate_sdk::service::Service for #serve_ident<#service_ident> { type Future = ::restate_sdk::service::ServiceBoxFuture; @@ -205,8 +184,7 @@ impl<'a> ServiceGenerator<'a> { }); quote! { - impl ::restate_sdk::service::Discoverable for #serve_ident - where S: #service_ident, + impl::restate_sdk::service::Discoverable for #serve_ident<#service_ident> { fn discover() -> ::restate_sdk::discovery::Service { ::restate_sdk::discovery::Service { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8290af3..1ee0a80 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -16,15 +16,19 @@ extern crate proc_macro; mod ast; mod gen; -use crate::ast::{Object, Service, Workflow}; +use crate::ast::{Object, Service, ValidArgs, Workflow}; use crate::gen::ServiceGenerator; use proc_macro::TokenStream; use quote::ToTokens; use syn::parse_macro_input; #[proc_macro_attribute] -pub fn service(_: TokenStream, input: TokenStream) -> TokenStream { - let svc = parse_macro_input!(input as Service); +pub fn service(args: TokenStream, input: TokenStream) -> TokenStream { + let mut svc = parse_macro_input!(input as Service); + + let args = parse_macro_input!(args as ValidArgs); + svc.0.restate_name = args.restate_name.unwrap_or(svc.0.ident.to_string()); + svc.0.vis = args.vis; ServiceGenerator::new_service(&svc) .into_token_stream() @@ -32,8 +36,12 @@ pub fn service(_: TokenStream, input: TokenStream) -> TokenStream { } #[proc_macro_attribute] -pub fn object(_: TokenStream, input: TokenStream) -> TokenStream { - let svc = parse_macro_input!(input as Object); +pub fn object(args: TokenStream, input: TokenStream) -> TokenStream { + let mut svc = parse_macro_input!(input as Object); + + let args = parse_macro_input!(args as ValidArgs); + svc.0.restate_name = args.restate_name.unwrap_or(svc.0.ident.to_string()); + svc.0.vis = args.vis; ServiceGenerator::new_object(&svc) .into_token_stream() @@ -41,8 +49,12 @@ pub fn object(_: TokenStream, input: TokenStream) -> TokenStream { } #[proc_macro_attribute] -pub fn workflow(_: TokenStream, input: TokenStream) -> TokenStream { - let svc = parse_macro_input!(input as Workflow); +pub fn workflow(args: TokenStream, input: TokenStream) -> TokenStream { + let mut svc = parse_macro_input!(input as Workflow); + + let args = parse_macro_input!(args as ValidArgs); + svc.0.restate_name = args.restate_name.unwrap_or(svc.0.ident.to_string()); + svc.0.vis = args.vis; ServiceGenerator::new_workflow(&svc) .into_token_stream() diff --git a/src/context/mod.rs b/src/context/mod.rs index 2432f09..25f890f 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -445,15 +445,20 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { self.inner_context().invocation_handle(invocation_id) } - /// Create a service client. The service client is generated by the [`restate_sdk_macros::service`] macro with the same name of the trait suffixed with `Client`. + /// Create a service client. The service client is generated by the [`restate_sdk_macros::service`] macro with the same name of the struct suffixed with `Client`. /// /// ```rust,no_run /// # use std::time::Duration; /// # use restate_sdk::prelude::*; /// + /// struct MyService; + /// /// #[restate_sdk::service] - /// trait MyService { - /// async fn handle() -> HandlerResult<()>; + /// impl MyService { + /// #[handler] + /// async fn handle(&self, _ctx: Context<'_>,) -> HandlerResult<()> { + /// Ok(()) + /// } /// } /// /// # async fn handler(ctx: Context<'_>) { @@ -476,15 +481,20 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { C::create_client(self.inner_context()) } - /// Create an object client. The object client is generated by the [`restate_sdk_macros::object`] macro with the same name of the trait suffixed with `Client`. + /// Create an object client. The object client is generated by the [`restate_sdk_macros::object`] macro with the same name of the struct suffixed with `Client`. /// /// ```rust,no_run /// # use std::time::Duration; /// # use restate_sdk::prelude::*; /// + /// struct MyObject; + /// /// #[restate_sdk::object] - /// trait MyObject { - /// async fn handle() -> HandlerResult<()>; + /// impl MyObject { + /// #[handler] + /// async fn handle(&self, _ctx: ObjectContext<'_>,) -> HandlerResult<()> { + /// Ok(()) + /// } /// } /// /// # async fn handler(ctx: Context<'_>) { @@ -507,15 +517,20 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { C::create_client(self.inner_context(), key.into()) } - /// Create an workflow client. The workflow client is generated by the [`restate_sdk_macros::workflow`] macro with the same name of the trait suffixed with `Client`. + /// Create an workflow client. The workflow client is generated by the [`restate_sdk_macros::workflow`] macro with the same name of the struct suffixed with `Client`. /// /// ```rust,no_run /// # use std::time::Duration; /// # use restate_sdk::prelude::*; /// + /// struct MyWorkflow; + /// /// #[restate_sdk::workflow] - /// trait MyWorkflow { - /// async fn handle() -> HandlerResult<()>; + /// impl MyWorkflow { + /// #[handler] + /// async fn handle(&self, _ctx: WorkflowContext<'_>,) -> HandlerResult<()> { + /// Ok(()) + /// } /// } /// /// # async fn handler(ctx: Context<'_>) { diff --git a/src/http_server.rs b/src/http_server.rs index fb78a7a..c9fc095 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -9,9 +9,9 @@ //! ```rust,no_run //! # #[path = "../examples/services/mod.rs"] //! # mod services; -//! # use services::my_service::{MyService, MyServiceImpl}; -//! # use services::my_virtual_object::{MyVirtualObject, MyVirtualObjectImpl}; -//! # use services::my_workflow::{MyWorkflow, MyWorkflowImpl}; +//! # use services::my_service::MyService; +//! # use services::my_virtual_object::MyVirtualObject; +//! # use services::my_workflow::MyWorkflow; //! use restate_sdk::endpoint::Endpoint; //! use restate_sdk::http_server::HttpServer; //! @@ -19,9 +19,9 @@ //! async fn main() { //! HttpServer::new( //! Endpoint::builder() -//! .bind(MyServiceImpl.serve()) -//! .bind(MyVirtualObjectImpl.serve()) -//! .bind(MyWorkflowImpl.serve()) +//! .bind(MyService.serve()) +//! .bind(MyVirtualObject.serve()) +//! .bind(MyWorkflow.serve()) //! .build(), //! ) //! .listen_and_serve("0.0.0.0:9080".parse().unwrap()) @@ -39,7 +39,7 @@ //! ```rust,no_run //! # #[path = "../examples/services/mod.rs"] //! # mod services; -//! # use services::my_service::{MyService, MyServiceImpl}; +//! # use services::my_service::MyService; //! # use restate_sdk::endpoint::Endpoint; //! # use restate_sdk::http_server::HttpServer; //! # @@ -47,7 +47,7 @@ //! # async fn main() { //! HttpServer::new( //! Endpoint::builder() -//! .bind(MyServiceImpl.serve()) +//! .bind(MyService.serve()) //! .identity_key("publickeyv1_w7YHemBctH5Ck2nQRQ47iBBqhNHy4FV7t2Usbye2A6f") //! .unwrap() //! .build(), diff --git a/src/lib.rs b/src/lib.rs index 533e192..b7e0fa0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,40 +46,38 @@ //! // The prelude contains all the imports you need to get started //! use restate_sdk::prelude::*; //! -//! // Define the service using Rust traits -//! #[restate_sdk::service] -//! trait MyService { -//! async fn my_handler(greeting: String) -> Result; -//! } -//! -//! // Implement the service -//! struct MyServiceImpl; -//! impl MyService for MyServiceImpl { +//! struct MyService; //! +//! #[restate_sdk::service] +//! impl MyService { +//! #[handler] //! async fn my_handler(&self, ctx: Context<'_>, greeting: String) -> Result { //! Ok(format!("{greeting}!")) //! } -//! //! } //! //! // Start the HTTP server to expose services //! #[tokio::main] //! async fn main() { -//! HttpServer::new(Endpoint::builder().bind(MyServiceImpl.serve()).build()) +//! HttpServer::new(Endpoint::builder().bind(MyService.serve()).build()) //! .listen_and_serve("0.0.0.0:9080".parse().unwrap()) //! .await; //! } //! ``` //! +//! - Create a Concreate type (e.g. a struct) that you want to be a service //! - Specify that you want to create a service by using the [`#[restate_sdk::service]` macro](restate_sdk_macros::service). -//! - Create a trait with the service handlers. -//! - Handlers can accept zero or one parameter and return a [`Result`]. +//! - This macro has other attributes like `vis = "pub"` and `name = "my_service"` to control the +//! visibility of the generated `Serve` and `Client` code and to override the service name +//! respectively. +//! - Mark struct methods as handlers, you should use `#[handler]` macro which also has the +//! `name = "my_handler"` attribute to override the handler name. +//! - Handlers are `async` methods +//! - The first parameter of a handler after `&self` is always a [`Context`](crate::context::Context) to interact with Restate. +//! The SDK stores the actions you do on the context in the Restate journal to make them durable. Then it can accept zero or one parameter and return a [`Result`]. //! - The type of the input parameter of the handler needs to implement [`Serialize`](crate::serde::Deserialize) and [`Deserialize`](crate::serde::Deserialize). See [`crate::serde`]. //! - The Result contains the return value or a [`HandlerError`][crate::errors::HandlerError], which can be a [`TerminalError`](crate::errors::TerminalError) or any other Rust's [`std::error::Error`]. -//! - The service handler can now be called at `/MyService/myHandler`. You can optionally override the handler name used via `#[name = "myHandler"]`. More details on handler invocations can be found in the [docs](https://docs.restate.dev/invoke/http). -//! - Implement the trait on a concrete type, for example on a struct. -//! - The first parameter of a handler after `&self` is always a [`Context`](crate::context::Context) to interact with Restate. -//! The SDK stores the actions you do on the context in the Restate journal to make them durable. +//! - The service handler can now be called at `/MyService/myHandler`. More details on handler invocations can be found in the [docs](https://docs.restate.dev/invoke/http). //! - Finally, create an HTTP endpoint and bind the service(s) to it. Listen on the specified port (here 9080) for connections and requests. //! //! ## Virtual Objects @@ -88,17 +86,12 @@ //! ```rust,no_run //!use restate_sdk::prelude::*; //! -//! #[restate_sdk::object] -//! pub trait MyVirtualObject { -//! async fn my_handler(name: String) -> Result; -//! #[shared] -//! async fn my_concurrent_handler(name: String) -> Result; -//! } +//! pub struct MyVirtualObject; //! -//! pub struct MyVirtualObjectImpl; -//! -//! impl MyVirtualObject for MyVirtualObjectImpl { +//! #[restate_sdk::object] +//! impl MyVirtualObject { //! +//! #[handler] //! async fn my_handler( //! &self, //! ctx: ObjectContext<'_>, @@ -107,6 +100,7 @@ //! Ok(format!("{} {}", greeting, ctx.key())) //! } //! +//! #[handler(shared)] //! async fn my_concurrent_handler( //! &self, //! ctx: SharedObjectContext<'_>, @@ -121,7 +115,7 @@ //! async fn main() { //! HttpServer::new( //! Endpoint::builder() -//! .bind(MyVirtualObjectImpl.serve()) +//! .bind(MyVirtualObject.serve()) //! .build(), //! ) //! .listen_and_serve("0.0.0.0:9080".parse().unwrap()) @@ -130,9 +124,12 @@ //! ``` //! //! - Specify that you want to create a Virtual Object by using the [`#[restate_sdk::object]` macro](restate_sdk_macros::object). -//! - The first argument of each handler must be the [`ObjectContext`](crate::context::ObjectContext) parameter. Handlers with the `ObjectContext` parameter can write to the K/V state store. Only one handler can be active at a time per object, to ensure consistency. +//! You can use also the `vis` and the `name` attributes as with the [`#[restate_sdk::service]` macro](restate_sdk_macros::service). +//! - Object Handlers has additional attribute beside the `name`, it's the `shared` attribute.
An +//! example would be `#[handler(shared, name = "my_handler")]` +//! - The first argument of each handler after `&self`, must be the [`ObjectContext`](crate::context::ObjectContext) parameter. Handlers with the `ObjectContext` parameter can write to the K/V state store. Only one handler can be active at a time per object, to ensure consistency. //! - You can retrieve the key of the object you are in via [`ObjectContext.key`]. -//! - If you want to have a handler that executes concurrently to the others and doesn't have write access to the K/V state, add `#[shared]` to the handler definition in the trait. +//! - If you want to have a handler that executes concurrently to the others and doesn't have write access to the K/V state, use the `shared` attribute like `#[handler(shared)]`. //! Shared handlers need to use the [`SharedObjectContext`](crate::context::SharedObjectContext). //! You can use these handlers, for example, to read K/V state and expose it to the outside world, or to interact with the blocking handler and resolve awakeables etc. //! @@ -143,23 +140,19 @@ //! ```rust,no_run //! use restate_sdk::prelude::*; //! -//! #[restate_sdk::workflow] -//! pub trait MyWorkflow { -//! async fn run(req: String) -> Result; -//! #[shared] -//! async fn interact_with_workflow() -> Result<(), HandlerError>; -//! } -//! -//! pub struct MyWorkflowImpl; +//! pub struct MyWorkflow; //! -//! impl MyWorkflow for MyWorkflowImpl { +//! #[restate_sdk::workflow(vis = "pub")] +//! impl MyWorkflow { //! +//! #[handler] //! async fn run(&self, ctx: WorkflowContext<'_>, req: String) -> Result { //! //! implement workflow logic here //! //! Ok(String::from("success")) //! } //! +//! #[handler(shared)] //! async fn interact_with_workflow(&self, ctx: SharedWorkflowContext<'_>) -> Result<(), HandlerError> { //! //! implement interaction logic here //! //! e.g. resolve a promise that the workflow is waiting on @@ -171,17 +164,19 @@ //! //! #[tokio::main] //! async fn main() { -//! HttpServer::new(Endpoint::builder().bind(MyWorkflowImpl.serve()).build()) +//! HttpServer::new(Endpoint::builder().bind(MyWorkflow.serve()).build()) //! .listen_and_serve("0.0.0.0:9080".parse().unwrap()) //! .await; //! } //! ``` //! -//! - Specify that you want to create a Workflow by using the [`#[restate_sdk::workflow]` macro](workflow). +//! - Specify that you want to create a Workflow by using the [`#[restate_sdk::workflow]` macro](workflow).
+//! It also supports the `name` and the `vis` attributes +//! - Workflow Handlers supports the `name` and the `shared` attributes. //! - The workflow needs to have a `run` handler. -//! - The first argument of the `run` handler must be the [`WorkflowContext`](crate::context::WorkflowContext) parameter. -//! The `WorkflowContext` parameter is used to interact with Restate. -//! The `run` handler executes exactly once per workflow instance. +//! - The first argument of the `run` handler after the `&self`, must be the +//! [`WorkflowContext`](crate::context::WorkflowContext) parameter. The `WorkflowContext` parameter is used to +//! interact with Restate. The `run` handler executes exactly once per workflow instance. //! - The other handlers of the workflow are used to interact with the workflow: either query it, or signal it. //! They use the [`SharedWorkflowContext`](crate::context::SharedWorkflowContext) to interact with the SDK. //! These handlers can run concurrently with the run handler and can still be called after the run handler has finished. @@ -233,50 +228,53 @@ pub mod serde; /// ```rust,no_run /// use restate_sdk::prelude::*; /// +/// struct Greeter; +/// /// #[restate_sdk::service] -/// trait Greeter { -/// async fn greet(name: String) -> Result; +/// impl Greeter { +/// # #[handler] +/// # async fn greet(&self, _ctx: Context<'_>, name: String) -> Result { +/// # unimplemented!() +/// # } /// } /// ``` /// -/// This macro accepts a `trait` as input, and generates as output: +/// This macro accepts an `impl` as input, and it will: /// -/// * A trait with the same name, that you should implement on your own concrete type (e.g. `struct`), e.g.: +/// * validate that each service handler has the the appropriate [`Context`](crate::prelude::Context), to interact with Restate. /// /// ```rust,no_run /// # use restate_sdk::prelude::*; -/// # #[restate_sdk::service] -/// # trait Greeter { -/// # async fn greet(name: String) -> Result; -/// # } -/// struct GreeterImpl; -/// impl Greeter for GreeterImpl { -/// async fn greet(&self, _: Context<'_>, name: String) -> Result { +/// +/// struct Greeter; +/// +/// #[restate_sdk::service] +/// impl Greeter { +/// #[handler] +/// async fn greet(&self, _ctx: Context<'_>, name: String) -> Result { /// Ok(format!("Greetings {name}")) /// } /// } /// ``` /// -/// This trait will additionally contain, for each handler, the appropriate [`Context`](crate::prelude::Context), to interact with Restate. -/// -/// * An implementation of the [`Service`](crate::service::Service) trait, to bind the service in the [`Endpoint`](crate::prelude::Endpoint) and expose it: +/// * provide an implementation of the [`Service`](crate::service::Service) trait, to bind the service in the [`Endpoint`](crate::prelude::Endpoint) and expose it: /// /// ```rust,no_run /// # use restate_sdk::prelude::*; +/// # +/// # struct Greeter; +/// # /// # #[restate_sdk::service] -/// # trait Greeter { -/// # async fn greet(name: String) -> HandlerResult; -/// # } -/// # struct GreeterImpl; -/// # impl Greeter for GreeterImpl { -/// # async fn greet(&self, _: Context<'_>, name: String) -> HandlerResult { -/// # Ok(format!("Greetings {name}")) -/// # } +/// # impl Greeter { +/// # #[handler] +/// # async fn greet(&self, _ctx: Context<'_>, name: String) -> Result { +/// # Ok(format!("Greetings {name}")) +/// # } /// # } /// let endpoint = Endpoint::builder() /// // .serve() returns the implementation of Service used by the SDK /// // to bind your struct to the endpoint -/// .bind(GreeterImpl.serve()) +/// .bind(Greeter.serve()) /// .build(); /// ``` /// @@ -284,9 +282,15 @@ pub mod serde; /// /// ```rust,no_run /// # use restate_sdk::prelude::*; +/// # +/// # struct Greeter; +/// # /// # #[restate_sdk::service] -/// # trait Greeter { -/// # async fn greet(name: String) -> HandlerResult; +/// # impl Greeter { +/// # #[handler] +/// # async fn greet(&self, _ctx: Context<'_>, name: String) -> Result { +/// # Ok(format!("Greetings {name}")) +/// # } /// # } /// # async fn example(ctx: Context<'_>) -> Result<(), TerminalError> { /// let result = ctx @@ -298,17 +302,28 @@ pub mod serde; /// # } /// ``` /// -/// Methods of this trait can accept either no parameter, or one parameter implementing [`Deserialize`](crate::serde::Deserialize). -/// The return value MUST always be a `Result`. Down the hood, the error type is always converted to [`HandlerError`](crate::prelude::HandlerError) for the SDK to distinguish between terminal and retryable errors. For more details, check the [`HandlerError`](crate::prelude::HandlerError) doc. +/// Handler Methods of must have a `&self` and an appropriate [`Context`](crate::prelude::Context) +/// and they can accept either no parameter, or one parameter implementing +/// [`Deserialize`](crate::serde::Deserialize). +/// +/// The return value MUST always be a `Result`. Down the hood, the error type is always converted +/// to [`HandlerError`](crate::prelude::HandlerError) for the SDK to distinguish between terminal +/// and retryable errors. For more details, check the +/// [`HandlerError`](crate::prelude::HandlerError) doc. /// /// When invoking the service through Restate, the method name should be used as handler name, that is: /// /// ```rust,no_run /// use restate_sdk::prelude::*; /// +/// struct Greeter; +/// /// #[restate_sdk::service] -/// trait Greeter { -/// async fn my_greet(name: String) -> Result; +/// impl Greeter { +/// #[handler] +/// async fn my_greet(&self, _ctx: Context<'_>, name: String) -> Result { +/// Ok(format!("Greetings {name}")) +/// } /// } /// ``` /// @@ -318,12 +333,15 @@ pub mod serde; /// ```rust,no_run /// use restate_sdk::prelude::*; /// -/// #[restate_sdk::service] -/// #[name = "greeter"] -/// trait Greeter { -/// // You can invoke this handler with `http:///greeter/myGreet` -/// #[name = "myGreet"] -/// async fn my_greet(name: String) -> Result; +/// +/// struct Greeter; +/// +/// #[restate_sdk::service(name = "greeter")] +/// impl Greeter { +/// #[handler(name = "myGreet")] +/// async fn greet(&self, _ctx: Context<'_>, name: String) -> Result { +/// Ok(format!("Greetings {name}")) +/// } /// } /// ``` pub use restate_sdk_macros::service; @@ -334,16 +352,23 @@ pub use restate_sdk_macros::service; /// /// ## Shared handlers /// -/// To define a shared handler, simply annotate the handler with the `#[shared]` annotation: +/// To define a shared handler, simply annotate the handler with the `#[handler(shared)]` annotation: /// /// ```rust,no_run /// use restate_sdk::prelude::*; /// -/// #[restate_sdk::object] -/// trait Counter { -/// async fn add(val: u64) -> Result; -/// #[shared] -/// async fn get() -> Result; +/// pub struct Counter; +/// +/// #[restate_sdk::object(vis = "pub")] +/// impl Counter { +/// #[handler] +/// async fn add(&self, _ctx: ObjectContext<'_>, val: u64) -> Result { +/// unimplemented!() +/// } +/// #[handler(shared)] +/// async fn get(&self, _ctx: SharedObjectContext<'_>,) -> Result { +/// unimplemented!() +/// } /// } /// ``` pub use restate_sdk_macros::object; @@ -380,19 +405,12 @@ pub use restate_sdk_macros::object; /// ```rust,no_run /// use restate_sdk::prelude::*; /// -/// #[restate_sdk::workflow] -/// pub trait SignupWorkflow { -/// async fn run(req: String) -> Result; -/// #[shared] -/// async fn click(click_secret: String) -> Result<(), HandlerError>; -/// #[shared] -/// async fn get_status() -> Result; -/// } +/// pub struct SignupWorkflow; /// -/// pub struct SignupWorkflowImpl; -/// -/// impl SignupWorkflow for SignupWorkflowImpl { +/// #[restate_sdk::workflow] +/// impl SignupWorkflow { /// +/// #[handler] /// async fn run(&self, mut ctx: WorkflowContext<'_>, email: String) -> Result { /// let secret = ctx.rand_uuid().to_string(); /// ctx.run(|| send_email_with_link(email.clone(), secret.clone())).await?; @@ -404,11 +422,13 @@ pub use restate_sdk_macros::object; /// Ok(click_secret == secret) /// } /// +/// #[handler(shared)] /// async fn click(&self, ctx: SharedWorkflowContext<'_>, click_secret: String) -> Result<(), HandlerError> { /// ctx.resolve_promise::("email.clicked", click_secret); /// Ok(()) /// } /// +/// #[handler(shared)] /// async fn get_status(&self, ctx: SharedWorkflowContext<'_>) -> Result { /// Ok(ctx.get("status").await?.unwrap_or("unknown".to_string())) /// } @@ -420,7 +440,7 @@ pub use restate_sdk_macros::object; /// /// #[tokio::main] /// async fn main() { -/// HttpServer::new(Endpoint::builder().bind(SignupWorkflowImpl.serve()).build()) +/// HttpServer::new(Endpoint::builder().bind(SignupWorkflow.serve()).build()) /// .listen_and_serve("0.0.0.0:9080".parse().unwrap()) /// .await; /// } @@ -435,7 +455,7 @@ pub use restate_sdk_macros::object; /// /// ## Shared handlers /// -/// To define a shared handler, simply annotate the handler with the `#[shared]` annotation: +/// To define a shared handler, simply annotate the handler with the `#[handler(shared)]` annotation: /// /// ### Querying workflows /// diff --git a/test-env/tests/test_container.rs b/test-env/tests/test_container.rs index 647cf83..be30c41 100644 --- a/test-env/tests/test_container.rs +++ b/test-env/tests/test_container.rs @@ -4,29 +4,52 @@ use restate_sdk_test_env::TestContainer; use tracing::info; // Should compile -#[restate_sdk::service] -trait MyService { - async fn my_handler() -> HandlerResult; -} +pub(crate) struct MyObject; + +#[allow(dead_code)] +#[restate_sdk::object(vis = "pub(crate)")] +impl MyObject { + #[handler] + async fn my_handler(&self, _ctx: ObjectContext<'_>, _input: String) -> HandlerResult { + unimplemented!() + } -#[restate_sdk::object] -trait MyObject { - async fn my_handler(input: String) -> HandlerResult; - #[shared] - async fn my_shared_handler(input: String) -> HandlerResult; + #[handler(shared)] + async fn my_shared_handler( + &self, + _ctx: SharedObjectContext<'_>, + _input: String, + ) -> HandlerResult { + unimplemented!() + } } -#[restate_sdk::workflow] -trait MyWorkflow { - async fn my_handler(input: String) -> HandlerResult; - #[shared] - async fn my_shared_handler(input: String) -> HandlerResult; +pub(crate) struct MyWorkflow; + +#[allow(dead_code)] +#[restate_sdk::workflow(vis = "pub(crate)")] +impl MyWorkflow { + #[handler] + async fn my_handler(&self, _ctx: WorkflowContext<'_>, _input: String) -> HandlerResult { + unimplemented!() + } + + #[handler(shared)] + async fn my_shared_handler( + &self, + _ctx: SharedWorkflowContext<'_>, + _input: String, + ) -> HandlerResult { + unimplemented!() + } } -struct MyServiceImpl; +pub(crate) struct MyService; -impl MyService for MyServiceImpl { - async fn my_handler(&self, _: Context<'_>) -> HandlerResult { +#[restate_sdk::service(vis = "pub(crate)")] +impl MyService { + #[handler] + async fn my_handler(&self, _ctx: Context<'_>) -> HandlerResult { let result = "hello!"; Ok(result.to_string()) } @@ -38,7 +61,7 @@ async fn test_container() { .with_max_level(tracing::Level::INFO) // Set the maximum log level .init(); - let endpoint = Endpoint::builder().bind(MyServiceImpl.serve()).build(); + let endpoint = Endpoint::builder().bind(MyService.serve()).build(); // simple test container intialization with default configuration //let test_container = TestContainer::default().start(endpoint).await.unwrap(); diff --git a/test-services/src/awakeable_holder.rs b/test-services/src/awakeable_holder.rs index 1d28650..ae7c5d8 100644 --- a/test-services/src/awakeable_holder.rs +++ b/test-services/src/awakeable_holder.rs @@ -1,31 +1,23 @@ use restate_sdk::prelude::*; -#[restate_sdk::object] -#[name = "AwakeableHolder"] -pub(crate) trait AwakeableHolder { - #[name = "hold"] - async fn hold(id: String) -> HandlerResult<()>; - #[name = "hasAwakeable"] - #[shared] - async fn has_awakeable() -> HandlerResult; - #[name = "unlock"] - async fn unlock(payload: String) -> HandlerResult<()>; -} - -pub(crate) struct AwakeableHolderImpl; +pub(crate) struct AwakeableHolder; const ID: &str = "id"; -impl AwakeableHolder for AwakeableHolderImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "AwakeableHolder")] +impl AwakeableHolder { + #[handler(name = "hold")] async fn hold(&self, context: ObjectContext<'_>, id: String) -> HandlerResult<()> { context.set(ID, id); Ok(()) } + #[handler(shared, name = "hasAwakeable")] async fn has_awakeable(&self, context: SharedObjectContext<'_>) -> HandlerResult { Ok(context.get::(ID).await?.is_some()) } + #[handler(name = "unlock")] async fn unlock(&self, context: ObjectContext<'_>, payload: String) -> HandlerResult<()> { let k: String = context.get(ID).await?.ok_or_else(|| { TerminalError::new(format!( diff --git a/test-services/src/block_and_wait_workflow.rs b/test-services/src/block_and_wait_workflow.rs index e9b092f..591aa79 100644 --- a/test-services/src/block_and_wait_workflow.rs +++ b/test-services/src/block_and_wait_workflow.rs @@ -1,24 +1,13 @@ use restate_sdk::prelude::*; -#[restate_sdk::workflow] -#[name = "BlockAndWaitWorkflow"] -pub(crate) trait BlockAndWaitWorkflow { - #[name = "run"] - async fn run(input: String) -> HandlerResult; - #[name = "unblock"] - #[shared] - async fn unblock(output: String) -> HandlerResult<()>; - #[name = "getState"] - #[shared] - async fn get_state() -> HandlerResult>>; -} - -pub(crate) struct BlockAndWaitWorkflowImpl; +pub(crate) struct BlockAndWaitWorkflow; const MY_PROMISE: &str = "my-promise"; const MY_STATE: &str = "my-state"; -impl BlockAndWaitWorkflow for BlockAndWaitWorkflowImpl { +#[restate_sdk::workflow(vis = "pub(crate)", name = "BlockAndWaitWorkflow")] +impl BlockAndWaitWorkflow { + #[handler(name = "run")] async fn run(&self, context: WorkflowContext<'_>, input: String) -> HandlerResult { context.set(MY_STATE, input); @@ -31,6 +20,7 @@ impl BlockAndWaitWorkflow for BlockAndWaitWorkflowImpl { Ok(promise) } + #[handler(shared, name = "unblock")] async fn unblock( &self, context: SharedWorkflowContext<'_>, @@ -40,6 +30,7 @@ impl BlockAndWaitWorkflow for BlockAndWaitWorkflowImpl { Ok(()) } + #[handler(shared, name = "getState")] async fn get_state( &self, context: SharedWorkflowContext<'_>, diff --git a/test-services/src/cancel_test.rs b/test-services/src/cancel_test.rs index 768b951..a3de272 100644 --- a/test-services/src/cancel_test.rs +++ b/test-services/src/cancel_test.rs @@ -12,20 +12,13 @@ pub(crate) enum BlockingOperation { Awakeable, } -#[restate_sdk::object] -#[name = "CancelTestRunner"] -pub(crate) trait CancelTestRunner { - #[name = "startTest"] - async fn start_test(op: Json) -> HandlerResult<()>; - #[name = "verifyTest"] - async fn verify_test() -> HandlerResult; -} - -pub(crate) struct CancelTestRunnerImpl; +pub(crate) struct CancelTestRunner; const CANCELED: &str = "canceled"; -impl CancelTestRunner for CancelTestRunnerImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "CancelTestRunner")] +impl CancelTestRunner { + #[handler(name = "startTest")] async fn start_test( &self, context: ObjectContext<'_>, @@ -43,23 +36,17 @@ impl CancelTestRunner for CancelTestRunnerImpl { } } + #[handler(name = "verifyTest")] async fn verify_test(&self, context: ObjectContext<'_>) -> HandlerResult { Ok(context.get::(CANCELED).await?.unwrap_or(false)) } } -#[restate_sdk::object] -#[name = "CancelTestBlockingService"] -pub(crate) trait CancelTestBlockingService { - #[name = "block"] - async fn block(op: Json) -> HandlerResult<()>; - #[name = "isUnlocked"] - async fn is_unlocked() -> HandlerResult<()>; -} - -pub(crate) struct CancelTestBlockingServiceImpl; +pub(crate) struct CancelTestBlockingService; -impl CancelTestBlockingService for CancelTestBlockingServiceImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "CancelTestBlockingService")] +impl CancelTestBlockingService { + #[handler(name = "block")] async fn block( &self, context: ObjectContext<'_>, @@ -91,7 +78,8 @@ impl CancelTestBlockingService for CancelTestBlockingServiceImpl { Ok(()) } - async fn is_unlocked(&self, _: ObjectContext<'_>) -> HandlerResult<()> { + #[handler(name = "isUnlocked")] + async fn is_unlocked(&self, _ctx: ObjectContext<'_>) -> HandlerResult<()> { // no-op Ok(()) } diff --git a/test-services/src/counter.rs b/test-services/src/counter.rs index 0b51d86..a869463 100644 --- a/test-services/src/counter.rs +++ b/test-services/src/counter.rs @@ -9,29 +9,18 @@ pub(crate) struct CounterUpdateResponse { new_value: u64, } -#[restate_sdk::object] -#[name = "Counter"] -pub(crate) trait Counter { - #[name = "add"] - async fn add(val: u64) -> HandlerResult>; - #[name = "addThenFail"] - async fn add_then_fail(val: u64) -> HandlerResult<()>; - #[shared] - #[name = "get"] - async fn get() -> HandlerResult; - #[name = "reset"] - async fn reset() -> HandlerResult<()>; -} - -pub(crate) struct CounterImpl; +pub(crate) struct Counter; const COUNT: &str = "counter"; -impl Counter for CounterImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "Counter")] +impl Counter { + #[handler(shared, name = "get")] async fn get(&self, ctx: SharedObjectContext<'_>) -> HandlerResult { Ok(ctx.get::(COUNT).await?.unwrap_or(0)) } + #[handler(name = "add")] async fn add( &self, ctx: ObjectContext<'_>, @@ -50,11 +39,13 @@ impl Counter for CounterImpl { .into()) } + #[handler(name = "reset")] async fn reset(&self, ctx: ObjectContext<'_>) -> HandlerResult<()> { ctx.clear(COUNT); Ok(()) } + #[handler(name = "addThenFail")] async fn add_then_fail(&self, ctx: ObjectContext<'_>, val: u64) -> HandlerResult<()> { let current = ctx.get::(COUNT).await?.unwrap_or(0); let new = current + val; diff --git a/test-services/src/failing.rs b/test-services/src/failing.rs index 197a7af..58326b0 100644 --- a/test-services/src/failing.rs +++ b/test-services/src/failing.rs @@ -4,42 +4,25 @@ use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::Arc; use std::time::Duration; -#[restate_sdk::object] -#[name = "Failing"] -pub(crate) trait Failing { - #[name = "terminallyFailingCall"] - async fn terminally_failing_call(error_message: String) -> HandlerResult<()>; - #[name = "callTerminallyFailingCall"] - async fn call_terminally_failing_call(error_message: String) -> HandlerResult; - #[name = "failingCallWithEventualSuccess"] - async fn failing_call_with_eventual_success() -> HandlerResult; - #[name = "terminallyFailingSideEffect"] - async fn terminally_failing_side_effect(error_message: String) -> HandlerResult<()>; - #[name = "sideEffectSucceedsAfterGivenAttempts"] - async fn side_effect_succeeds_after_given_attempts(minimum_attempts: i32) - -> HandlerResult; - #[name = "sideEffectFailsAfterGivenAttempts"] - async fn side_effect_fails_after_given_attempts( - retry_policy_max_retry_count: i32, - ) -> HandlerResult; -} - #[derive(Clone, Default)] -pub(crate) struct FailingImpl { +pub(crate) struct Failing { eventual_success_calls: Arc, eventual_success_side_effects: Arc, eventual_failure_side_effects: Arc, } -impl Failing for FailingImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "Failing")] +impl Failing { + #[handler(name = "terminallyFailingCall")] async fn terminally_failing_call( &self, - _: ObjectContext<'_>, + _ctx: ObjectContext<'_>, error_message: String, ) -> HandlerResult<()> { Err(TerminalError::new(error_message).into()) } + #[handler(name = "callTerminallyFailingCall")] async fn call_terminally_failing_call( &self, mut context: ObjectContext<'_>, @@ -55,7 +38,11 @@ impl Failing for FailingImpl { unreachable!("This should be unreachable") } - async fn failing_call_with_eventual_success(&self, _: ObjectContext<'_>) -> HandlerResult { + #[handler(name = "failingCallWithEventualSuccess")] + async fn failing_call_with_eventual_success( + &self, + _ctx: ObjectContext<'_>, + ) -> HandlerResult { let current_attempt = self.eventual_success_calls.fetch_add(1, Ordering::SeqCst) + 1; if current_attempt >= 4 { @@ -66,6 +53,7 @@ impl Failing for FailingImpl { } } + #[handler(name = "terminallyFailingSideEffect")] async fn terminally_failing_side_effect( &self, context: ObjectContext<'_>, @@ -78,6 +66,7 @@ impl Failing for FailingImpl { unreachable!("This should be unreachable") } + #[handler(name = "sideEffectSucceedsAfterGivenAttempts")] async fn side_effect_succeeds_after_given_attempts( &self, context: ObjectContext<'_>, @@ -106,6 +95,7 @@ impl Failing for FailingImpl { Ok(success_attempt) } + #[handler(name = "sideEffectFailsAfterGivenAttempts")] async fn side_effect_fails_after_given_attempts( &self, context: ObjectContext<'_>, diff --git a/test-services/src/kill_test.rs b/test-services/src/kill_test.rs index 4a87faf..a7a960c 100644 --- a/test-services/src/kill_test.rs +++ b/test-services/src/kill_test.rs @@ -1,16 +1,11 @@ use crate::awakeable_holder; use restate_sdk::prelude::*; -#[restate_sdk::object] -#[name = "KillTestRunner"] -pub(crate) trait KillTestRunner { - #[name = "startCallTree"] - async fn start_call_tree() -> HandlerResult<()>; -} - -pub(crate) struct KillTestRunnerImpl; +pub(crate) struct KillTestRunner; -impl KillTestRunner for KillTestRunnerImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "KillTestRunner")] +impl KillTestRunner { + #[handler(name = "startCallTree")] async fn start_call_tree(&self, context: ObjectContext<'_>) -> HandlerResult<()> { context .object_client::(context.key()) @@ -21,18 +16,11 @@ impl KillTestRunner for KillTestRunnerImpl { } } -#[restate_sdk::object] -#[name = "KillTestSingleton"] -pub(crate) trait KillTestSingleton { - #[name = "recursiveCall"] - async fn recursive_call() -> HandlerResult<()>; - #[name = "isUnlocked"] - async fn is_unlocked() -> HandlerResult<()>; -} - -pub(crate) struct KillTestSingletonImpl; +pub(crate) struct KillTestSingleton; -impl KillTestSingleton for KillTestSingletonImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "KillTestSingleton")] +impl KillTestSingleton { + #[handler(name = "recursiveCall")] async fn recursive_call(&self, context: ObjectContext<'_>) -> HandlerResult<()> { let awakeable_holder_client = context.object_client::(context.key()); @@ -50,7 +38,8 @@ impl KillTestSingleton for KillTestSingletonImpl { Ok(()) } - async fn is_unlocked(&self, _: ObjectContext<'_>) -> HandlerResult<()> { + #[handler(name = "isUnlocked")] + async fn is_unlocked(&self, _ctx: ObjectContext<'_>) -> HandlerResult<()> { // no-op Ok(()) } diff --git a/test-services/src/list_object.rs b/test-services/src/list_object.rs index 7988c9f..3cd99d1 100644 --- a/test-services/src/list_object.rs +++ b/test-services/src/list_object.rs @@ -1,21 +1,12 @@ use restate_sdk::prelude::*; -#[restate_sdk::object] -#[name = "ListObject"] -pub(crate) trait ListObject { - #[name = "append"] - async fn append(value: String) -> HandlerResult<()>; - #[name = "get"] - async fn get() -> HandlerResult>>; - #[name = "clear"] - async fn clear() -> HandlerResult>>; -} - -pub(crate) struct ListObjectImpl; +pub(crate) struct ListObject; const LIST: &str = "list"; -impl ListObject for ListObjectImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "ListObject")] +impl ListObject { + #[handler(name = "append")] async fn append(&self, ctx: ObjectContext<'_>, value: String) -> HandlerResult<()> { let mut list = ctx .get::>>(LIST) @@ -27,6 +18,7 @@ impl ListObject for ListObjectImpl { Ok(()) } + #[handler(name = "get")] async fn get(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { Ok(ctx .get::>>(LIST) @@ -34,6 +26,7 @@ impl ListObject for ListObjectImpl { .unwrap_or_default()) } + #[handler(name = "clear")] async fn clear(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { let get = ctx .get::>>(LIST) diff --git a/test-services/src/main.rs b/test-services/src/main.rs index bde49a9..249e4a6 100644 --- a/test-services/src/main.rs +++ b/test-services/src/main.rs @@ -23,64 +23,62 @@ async fn main() { let mut builder = Endpoint::builder(); if services == "*" || services.contains("Counter") { - builder = builder.bind(counter::Counter::serve(counter::CounterImpl)) + builder = builder.bind(counter::Counter::serve(counter::Counter)) } if services == "*" || services.contains("Proxy") { - builder = builder.bind(proxy::Proxy::serve(proxy::ProxyImpl)) + builder = builder.bind(proxy::Proxy::serve(proxy::Proxy)) } if services == "*" || services.contains("MapObject") { - builder = builder.bind(map_object::MapObject::serve(map_object::MapObjectImpl)) + builder = builder.bind(map_object::MapObject::serve(map_object::MapObject)) } if services == "*" || services.contains("ListObject") { - builder = builder.bind(list_object::ListObject::serve(list_object::ListObjectImpl)) + builder = builder.bind(list_object::ListObject::serve(list_object::ListObject)) } if services == "*" || services.contains("AwakeableHolder") { builder = builder.bind(awakeable_holder::AwakeableHolder::serve( - awakeable_holder::AwakeableHolderImpl, + awakeable_holder::AwakeableHolder, )) } if services == "*" || services.contains("BlockAndWaitWorkflow") { builder = builder.bind(block_and_wait_workflow::BlockAndWaitWorkflow::serve( - block_and_wait_workflow::BlockAndWaitWorkflowImpl, + block_and_wait_workflow::BlockAndWaitWorkflow, )) } if services == "*" || services.contains("CancelTestRunner") { builder = builder.bind(cancel_test::CancelTestRunner::serve( - cancel_test::CancelTestRunnerImpl, + cancel_test::CancelTestRunner, )) } if services == "*" || services.contains("CancelTestBlockingService") { builder = builder.bind(cancel_test::CancelTestBlockingService::serve( - cancel_test::CancelTestBlockingServiceImpl, + cancel_test::CancelTestBlockingService, )) } if services == "*" || services.contains("Failing") { - builder = builder.bind(failing::Failing::serve(failing::FailingImpl::default())) + builder = builder.bind(failing::Failing::serve(failing::Failing::default())) } if services == "*" || services.contains("KillTestRunner") { - builder = builder.bind(kill_test::KillTestRunner::serve( - kill_test::KillTestRunnerImpl, - )) + builder = builder.bind(kill_test::KillTestRunner::serve(kill_test::KillTestRunner)) } if services == "*" || services.contains("KillTestSingleton") { builder = builder.bind(kill_test::KillTestSingleton::serve( - kill_test::KillTestSingletonImpl, + kill_test::KillTestSingleton, )) } if services == "*" || services.contains("NonDeterministic") { builder = builder.bind(non_deterministic::NonDeterministic::serve( - non_deterministic::NonDeterministicImpl::default(), + non_deterministic::NonDeterministic::default(), )) } if services == "*" || services.contains("TestUtilsService") { builder = builder.bind(test_utils_service::TestUtilsService::serve( - test_utils_service::TestUtilsServiceImpl, + test_utils_service::TestUtilsService, )) } if services == "*" || services.contains("VirtualObjectCommandInterpreter") { builder = builder.bind( virtual_object_command_interpreter::VirtualObjectCommandInterpreter::serve( - virtual_object_command_interpreter::VirtualObjectCommandInterpreterImpl, + virtual_object_command_interpreter::VirtualObjectCommandInterpreter, ), ) } diff --git a/test-services/src/map_object.rs b/test-services/src/map_object.rs index cf5ab76..38917de 100644 --- a/test-services/src/map_object.rs +++ b/test-services/src/map_object.rs @@ -9,20 +9,11 @@ pub(crate) struct Entry { value: String, } -#[restate_sdk::object] -#[name = "MapObject"] -pub(crate) trait MapObject { - #[name = "set"] - async fn set(entry: Json) -> HandlerResult<()>; - #[name = "get"] - async fn get(key: String) -> HandlerResult; - #[name = "clearAll"] - async fn clear_all() -> HandlerResult>>; -} - -pub(crate) struct MapObjectImpl; +pub(crate) struct MapObject; -impl MapObject for MapObjectImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "MapObject")] +impl MapObject { + #[handler(name = "set")] async fn set( &self, ctx: ObjectContext<'_>, @@ -32,10 +23,12 @@ impl MapObject for MapObjectImpl { Ok(()) } + #[handler(name = "get")] async fn get(&self, ctx: ObjectContext<'_>, key: String) -> HandlerResult { Ok(ctx.get(&key).await?.unwrap_or_default()) } + #[handler(name = "clearAll")] async fn clear_all(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { let keys = ctx.get_keys().await?; diff --git a/test-services/src/non_deterministic.rs b/test-services/src/non_deterministic.rs index 7614bf8..32feb77 100644 --- a/test-services/src/non_deterministic.rs +++ b/test-services/src/non_deterministic.rs @@ -5,26 +5,15 @@ use std::sync::Arc; use std::time::Duration; use tokio::sync::Mutex; -#[restate_sdk::object] -#[name = "NonDeterministic"] -pub(crate) trait NonDeterministic { - #[name = "eitherSleepOrCall"] - async fn either_sleep_or_call() -> HandlerResult<()>; - #[name = "callDifferentMethod"] - async fn call_different_method() -> HandlerResult<()>; - #[name = "backgroundInvokeWithDifferentTargets"] - async fn background_invoke_with_different_targets() -> HandlerResult<()>; - #[name = "setDifferentKey"] - async fn set_different_key() -> HandlerResult<()>; -} - #[derive(Clone, Default)] -pub(crate) struct NonDeterministicImpl(Arc>>); +pub(crate) struct NonDeterministic(Arc>>); const STATE_A: &str = "a"; const STATE_B: &str = "b"; -impl NonDeterministic for NonDeterministicImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "NonDeterministic")] +impl NonDeterministic { + #[handler(name = "eitherSleepOrCall")] async fn either_sleep_or_call(&self, context: ObjectContext<'_>) -> HandlerResult<()> { if self.do_left_action(&context).await { context.sleep(Duration::from_millis(100)).await?; @@ -38,6 +27,7 @@ impl NonDeterministic for NonDeterministicImpl { Self::sleep_then_increment_counter(&context).await } + #[handler(name = "callDifferentMethod")] async fn call_different_method(&self, context: ObjectContext<'_>) -> HandlerResult<()> { if self.do_left_action(&context).await { context @@ -55,6 +45,7 @@ impl NonDeterministic for NonDeterministicImpl { Self::sleep_then_increment_counter(&context).await } + #[handler(name = "backgroundInvokeWithDifferentTargets")] async fn background_invoke_with_different_targets( &self, context: ObjectContext<'_>, @@ -67,6 +58,7 @@ impl NonDeterministic for NonDeterministicImpl { Self::sleep_then_increment_counter(&context).await } + #[handler(name = "setDifferentKey")] async fn set_different_key(&self, context: ObjectContext<'_>) -> HandlerResult<()> { if self.do_left_action(&context).await { context.set(STATE_A, "my-state".to_owned()); @@ -77,7 +69,7 @@ impl NonDeterministic for NonDeterministicImpl { } } -impl NonDeterministicImpl { +impl NonDeterministic { async fn do_left_action(&self, ctx: &ObjectContext<'_>) -> bool { let mut counts = self.0.lock().await; *(counts diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs index 6b1f221..c0a3b48 100644 --- a/test-services/src/proxy.rs +++ b/test-services/src/proxy.rs @@ -41,20 +41,11 @@ pub(crate) struct ManyCallRequest { await_at_the_end: bool, } -#[restate_sdk::service] -#[name = "Proxy"] -pub(crate) trait Proxy { - #[name = "call"] - async fn call(req: Json) -> HandlerResult>>; - #[name = "oneWayCall"] - async fn one_way_call(req: Json) -> HandlerResult; - #[name = "manyCalls"] - async fn many_calls(req: Json>) -> HandlerResult<()>; -} - -pub(crate) struct ProxyImpl; +pub(crate) struct Proxy; -impl Proxy for ProxyImpl { +#[restate_sdk::service(vis = "pub(crate)", name = "Proxy")] +impl Proxy { + #[handler(name = "call")] async fn call( &self, ctx: Context<'_>, @@ -67,6 +58,7 @@ impl Proxy for ProxyImpl { Ok(request.call().await?.into()) } + #[handler(name = "oneWayCall")] async fn one_way_call( &self, ctx: Context<'_>, @@ -89,6 +81,7 @@ impl Proxy for ProxyImpl { Ok(invocation_id) } + #[handler(name = "manyCalls")] async fn many_calls( &self, ctx: Context<'_>, diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index 0152062..6bd88dd 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -7,40 +7,26 @@ use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use std::time::Duration; -#[restate_sdk::service] -#[name = "TestUtilsService"] -pub(crate) trait TestUtilsService { - #[name = "echo"] - async fn echo(input: String) -> HandlerResult; - #[name = "uppercaseEcho"] - async fn uppercase_echo(input: String) -> HandlerResult; - #[name = "rawEcho"] - async fn raw_echo(input: Vec) -> Result, Infallible>; - #[name = "echoHeaders"] - async fn echo_headers() -> HandlerResult>>; - #[name = "sleepConcurrently"] - async fn sleep_concurrently(millis_durations: Json>) -> HandlerResult<()>; - #[name = "countExecutedSideEffects"] - async fn count_executed_side_effects(increments: u32) -> HandlerResult; - #[name = "cancelInvocation"] - async fn cancel_invocation(invocation_id: String) -> Result<(), TerminalError>; -} - -pub(crate) struct TestUtilsServiceImpl; +pub(crate) struct TestUtilsService; -impl TestUtilsService for TestUtilsServiceImpl { - async fn echo(&self, _: Context<'_>, input: String) -> HandlerResult { +#[restate_sdk::service(vis = "pub(crate)", name = "TestUtilsService")] +impl TestUtilsService { + #[handler(name = "echo")] + async fn echo(&self, _ctx: Context<'_>, input: String) -> HandlerResult { Ok(input) } - async fn uppercase_echo(&self, _: Context<'_>, input: String) -> HandlerResult { + #[handler(name = "uppercaseEcho")] + async fn uppercase_echo(&self, _ctx: Context<'_>, input: String) -> HandlerResult { Ok(input.to_ascii_uppercase()) } - async fn raw_echo(&self, _: Context<'_>, input: Vec) -> Result, Infallible> { + #[handler(name = "rawEcho")] + async fn raw_echo(&self, _ctx: Context<'_>, input: Vec) -> Result, Infallible> { Ok(input) } + #[handler(name = "echoHeaders")] async fn echo_headers( &self, context: Context<'_>, @@ -56,6 +42,7 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(headers.into()) } + #[handler(name = "sleepConcurrently")] async fn sleep_concurrently( &self, context: Context<'_>, @@ -74,6 +61,7 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(()) } + #[handler(name = "countExecutedSideEffects")] async fn count_executed_side_effects( &self, context: Context<'_>, @@ -94,6 +82,7 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(counter.load(Ordering::SeqCst) as u32) } + #[handler(name = "cancelInvocation")] async fn cancel_invocation( &self, ctx: Context<'_>, diff --git a/test-services/src/virtual_object_command_interpreter.rs b/test-services/src/virtual_object_command_interpreter.rs index d401c91..4620154 100644 --- a/test-services/src/virtual_object_command_interpreter.rs +++ b/test-services/src/virtual_object_command_interpreter.rs @@ -65,32 +65,11 @@ pub(crate) struct RejectAwakeable { reason: String, } -#[restate_sdk::object] -#[name = "VirtualObjectCommandInterpreter"] -pub(crate) trait VirtualObjectCommandInterpreter { - #[name = "interpretCommands"] - async fn interpret_commands(req: Json) -> HandlerResult; +pub(crate) struct VirtualObjectCommandInterpreter; - #[name = "resolveAwakeable"] - #[shared] - async fn resolve_awakeable(req: Json) -> HandlerResult<()>; - - #[name = "rejectAwakeable"] - #[shared] - async fn reject_awakeable(req: Json) -> HandlerResult<()>; - - #[name = "hasAwakeable"] - #[shared] - async fn has_awakeable(awakeable_key: String) -> HandlerResult; - - #[name = "getResults"] - #[shared] - async fn get_results() -> HandlerResult>>; -} - -pub(crate) struct VirtualObjectCommandInterpreterImpl; - -impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { +#[restate_sdk::object(vis = "pub(crate)", name = "VirtualObjectCommandInterpreter")] +impl VirtualObjectCommandInterpreter { + #[handler(name = "interpretCommands")] async fn interpret_commands( &self, context: ObjectContext<'_>, @@ -193,6 +172,7 @@ impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { Ok(last_result) } + #[handler(name = "resolveAwakeable", shared)] async fn resolve_awakeable( &self, context: SharedObjectContext<'_>, @@ -216,6 +196,7 @@ impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { Ok(()) } + #[handler(name = "rejectAwakeable", shared)] async fn reject_awakeable( &self, context: SharedObjectContext<'_>, @@ -239,6 +220,7 @@ impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { Ok(()) } + #[handler(name = "hasAwakeable", shared)] async fn has_awakeable( &self, context: SharedObjectContext<'_>, @@ -250,6 +232,7 @@ impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { .is_some()) } + #[handler(name = "getResults", shared)] async fn get_results( &self, context: SharedObjectContext<'_>, diff --git a/tests/service.rs b/tests/service.rs index 55bdec1..9271bcc 100644 --- a/tests/service.rs +++ b/tests/service.rs @@ -1,48 +1,94 @@ use restate_sdk::prelude::*; // Should compile -#[restate_sdk::service] -trait MyService { - async fn my_handler(input: String) -> HandlerResult; - async fn no_input() -> HandlerResult; +pub(crate) struct MyService; - async fn no_output() -> HandlerResult<()>; +#[allow(dead_code)] +#[restate_sdk::service(vis = "pub(crate)")] +impl MyService { + #[handler] + async fn my_handler(&self, _ctx: Context<'_>, _input: String) -> HandlerResult { + unimplemented!() + } - async fn no_input_no_output() -> HandlerResult<()>; + #[handler] + async fn no_input(&self, _ctx: Context<'_>) -> HandlerResult { + unimplemented!() + } - async fn std_result() -> Result<(), std::io::Error>; + #[handler] + async fn no_output(&self, _ctx: Context<'_>) -> HandlerResult<()> { + unimplemented!() + } - async fn std_result_with_terminal_error() -> Result<(), TerminalError>; + #[handler] + async fn no_input_no_output(&self, _ctx: Context<'_>) -> HandlerResult<()> { + unimplemented!() + } - async fn std_result_with_handler_error() -> Result<(), HandlerError>; -} + #[handler] + async fn std_result(&self, _ctx: Context<'_>) -> Result<(), std::io::Error> { + unimplemented!() + } + + #[handler] + async fn std_result_with_terminal_error(&self, _ctx: Context<'_>) -> Result<(), TerminalError> { + unimplemented!() + } -#[restate_sdk::object] -trait MyObject { - async fn my_handler(input: String) -> HandlerResult; - #[shared] - async fn my_shared_handler(input: String) -> HandlerResult; + #[handler] + async fn std_result_with_handler_error(&self, _ctx: Context<'_>) -> Result<(), HandlerError> { + unimplemented!() + } } -#[restate_sdk::workflow] -trait MyWorkflow { - async fn my_handler(input: String) -> HandlerResult; - #[shared] - async fn my_shared_handler(input: String) -> HandlerResult; +pub(crate) struct MyObject; + +#[allow(dead_code)] +#[restate_sdk::object(vis = "pub(crate)")] +impl MyObject { + #[handler] + async fn my_handler(&self, _ctx: ObjectContext<'_>, _input: String) -> HandlerResult { + unimplemented!() + } + + #[handler(shared)] + async fn my_shared_handler( + &self, + _ctx: SharedObjectContext<'_>, + _input: String, + ) -> HandlerResult { + unimplemented!() + } } -#[restate_sdk::service] -#[name = "myRenamedService"] -trait MyRenamedService { - #[name = "myRenamedHandler"] - async fn my_handler() -> HandlerResult<()>; +pub(crate) struct MyWorkflow; + +#[allow(dead_code)] +#[restate_sdk::workflow(vis = "pub(crate)")] +impl MyWorkflow { + #[handler] + async fn my_handler(&self, _ctx: WorkflowContext<'_>, _input: String) -> HandlerResult { + unimplemented!() + } + + #[handler(shared)] + async fn my_shared_handler( + &self, + _ctx: SharedWorkflowContext<'_>, + _input: String, + ) -> HandlerResult { + unimplemented!() + } } -struct MyRenamedServiceImpl; +pub(crate) struct MyRenamedService; -impl MyRenamedService for MyRenamedServiceImpl { - async fn my_handler(&self, _: Context<'_>) -> HandlerResult<()> { +#[restate_sdk::service(vis = "pub(crate)", name = "myRenamedService")] +impl MyRenamedService { + #[handler(name = "myRenamedHandler")] + async fn my_handler(&self, _ctx: Context<'_>) -> HandlerResult<()> { Ok(()) } } @@ -51,7 +97,7 @@ impl MyRenamedService for MyRenamedServiceImpl { fn renamed_service_handler() { use restate_sdk::service::Discoverable; - let discovery = ServeMyRenamedService::::discover(); + let discovery = ServeMyRenamedService::::discover(); assert_eq!(discovery.name.to_string(), "myRenamedService"); assert_eq!(discovery.handlers[0].name.to_string(), "myRenamedHandler"); } diff --git a/tests/ui/shared_handler_in_service.rs b/tests/ui/shared_handler_in_service.rs index ef98a45..e6cd57d 100644 --- a/tests/ui/shared_handler_in_service.rs +++ b/tests/ui/shared_handler_in_service.rs @@ -1,15 +1,11 @@ use restate_sdk::prelude::*; -#[restate_sdk::service] -trait SharedHandlerInService { - #[shared] - async fn my_handler() -> HandlerResult<()>; -} - -struct SharedHandlerInServiceImpl; +struct SharedHandlerInService; -impl SharedHandlerInService for SharedHandlerInServiceImpl { - async fn my_handler(&self, _: Context<'_>) -> HandlerResult<()> { +#[restate_sdk::service] +impl SharedHandlerInService { + #[handler(shared)] + async fn my_handler(&self, _ctx: Context<'_>) -> HandlerResult<()> { Ok(()) } } @@ -19,9 +15,9 @@ async fn main() { tracing_subscriber::fmt::init(); HttpServer::new( Endpoint::builder() - .with_service(SharedHandlerInServiceImpl.serve()) + .with_service(SharedHandlerInService.serve()) .build(), ) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) .await; -} \ No newline at end of file +} diff --git a/tests/ui/shared_handler_in_service.stderr b/tests/ui/shared_handler_in_service.stderr index 6fd163a..44a4386 100644 --- a/tests/ui/shared_handler_in_service.stderr +++ b/tests/ui/shared_handler_in_service.stderr @@ -1,11 +1,23 @@ -error: Service handlers cannot be annotated with #[shared] - --> tests/ui/shared_handler_in_service.rs:6:14 +error: Service handlers cannot be annotated with #[handler(shared)] + --> tests/ui/shared_handler_in_service.rs:7:15 | -6 | async fn my_handler() -> HandlerResult<()>; - | ^^^^^^^^^^ +7 | #[handler(shared)] + | ^^^^^^ -error[E0405]: cannot find trait `SharedHandlerInService` in this scope - --> tests/ui/shared_handler_in_service.rs:11:6 +error[E0599]: no method named `with_service` found for struct `restate_sdk::endpoint::Builder` in the current scope + --> tests/ui/shared_handler_in_service.rs:18:14 | -11 | impl SharedHandlerInService for SharedHandlerInServiceImpl { - | ^^^^^^^^^^^^^^^^^^^^^^ not found in this scope +17 | / Endpoint::builder() +18 | | .with_service(SharedHandlerInService.serve()) + | | -^^^^^^^^^^^^ method not found in `Builder` + | |_____________| + | + +error[E0599]: no method named `serve` found for struct `SharedHandlerInService` in the current scope + --> tests/ui/shared_handler_in_service.rs:18:50 + | +3 | struct SharedHandlerInService; + | ----------------------------- method `serve` not found for this struct +... +18 | .with_service(SharedHandlerInService.serve()) + | ^^^^^ method not found in `SharedHandlerInService`