Skip to content

Refactor Macros to Apply to Impl Blocks for Ergonomics (#43) #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,11 @@ Then you're ready to develop your Restate service using Rust:
```rust
use restate_sdk::prelude::*;

#[restate_sdk::service]
trait Greeter {
async fn greet(name: String) -> HandlerResult<String>;
}

struct GreeterImpl;
pub(crate) struct Greeter;

impl Greeter for GreeterImpl {
#[restate_sdk::service(vis = "pub(crate)", name = "Greetings")]
impl Greeter {
#[handler]
async fn greet(&self, _: Context<'_>, name: String) -> HandlerResult<String> {
Ok(format!("Greetings {name}"))
}
Expand All @@ -50,7 +47,7 @@ async fn main() {
// tracing_subscriber::fmt::init();
HttpServer::new(
Endpoint::builder()
.with_service(GreeterImpl.serve())
.with_service(Greeter.serve())
.build(),
)
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
Expand All @@ -75,7 +72,7 @@ async fn test_container() {
.with_max_level(tracing::Level::INFO) // Set the maximum log level
.init();

let endpoint = Endpoint::builder().bind(MyServiceImpl.serve()).build();
let endpoint = Endpoint::builder().bind(MyService.serve()).build();

// simple test container intialization with default configuration
//let test_container = TestContainer::default().start(endpoint).await.unwrap();
Expand Down
22 changes: 9 additions & 13 deletions examples/counter.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,30 @@
use restate_sdk::prelude::*;

#[restate_sdk::object]
trait Counter {
#[shared]
async fn get() -> Result<u64, TerminalError>;
async fn add(val: u64) -> Result<u64, TerminalError>;
async fn increment() -> Result<u64, TerminalError>;
async fn reset() -> Result<(), TerminalError>;
}

struct CounterImpl;

const COUNT: &str = "count";

impl Counter for CounterImpl {
struct Counter;

#[restate_sdk::object]
impl Counter {
#[handler(shared)]
async fn get(&self, ctx: SharedObjectContext<'_>) -> Result<u64, TerminalError> {
Ok(ctx.get::<u64>(COUNT).await?.unwrap_or(0))
}

#[handler]
async fn add(&self, ctx: ObjectContext<'_>, val: u64) -> Result<u64, TerminalError> {
let current = ctx.get::<u64>(COUNT).await?.unwrap_or(0);
let new = current + val;
ctx.set(COUNT, new);
Ok(new)
}

#[handler]
async fn increment(&self, ctx: ObjectContext<'_>) -> Result<u64, TerminalError> {
self.add(ctx, 1).await
}

#[handler]
async fn reset(&self, ctx: ObjectContext<'_>) -> Result<(), TerminalError> {
ctx.clear(COUNT);
Ok(())
Expand All @@ -38,7 +34,7 @@ impl Counter for CounterImpl {
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
HttpServer::new(Endpoint::builder().bind(CounterImpl.serve()).build())
HttpServer::new(Endpoint::builder().bind(Counter.serve()).build())
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
.await;
}
26 changes: 10 additions & 16 deletions examples/cron.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,13 @@ use std::time::Duration;
/// ```shell
/// $ curl -v http://localhost:8080/PeriodicTask/my-periodic-task/start
/// ```
#[restate_sdk::object]
trait PeriodicTask {
/// Schedules the periodic task to start
async fn start() -> Result<(), TerminalError>;
/// Stops the periodic task
async fn stop() -> Result<(), TerminalError>;
/// Business logic of the periodic task
async fn run() -> Result<(), TerminalError>;
}

struct PeriodicTaskImpl;
struct PeriodicTask;

const ACTIVE: &str = "active";

impl PeriodicTask for PeriodicTaskImpl {
#[restate_sdk::object]
impl PeriodicTask {
#[handler]
async fn start(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> {
if context
.get::<bool>(ACTIVE)
Expand All @@ -39,21 +31,23 @@ impl PeriodicTask for PeriodicTaskImpl {
}

// Schedule the periodic task
PeriodicTaskImpl::schedule_next(&context);
PeriodicTask::schedule_next(&context);

// Mark the periodic task as active
context.set(ACTIVE, true);

Ok(())
}

#[handler]
async fn stop(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> {
// Remove the active flag
context.clear(ACTIVE);

Ok(())
}

#[handler]
async fn run(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> {
if context.get::<bool>(ACTIVE).await?.is_none() {
// Task is inactive, do nothing
Expand All @@ -64,13 +58,13 @@ impl PeriodicTask for PeriodicTaskImpl {
println!("Triggered the periodic task!");

// Schedule the periodic task
PeriodicTaskImpl::schedule_next(&context);
PeriodicTask::schedule_next(&context);

Ok(())
}
}

impl PeriodicTaskImpl {
impl PeriodicTask {
fn schedule_next(context: &ObjectContext<'_>) {
// To schedule, create a client to the callee handler (in this case, we're calling ourselves)
context
Expand All @@ -84,7 +78,7 @@ impl PeriodicTaskImpl {
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
HttpServer::new(Endpoint::builder().bind(PeriodicTaskImpl.serve()).build())
HttpServer::new(Endpoint::builder().bind(PeriodicTask.serve()).build())
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
.await;
}
16 changes: 6 additions & 10 deletions examples/failures.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
use rand::RngCore;
use restate_sdk::prelude::*;

#[restate_sdk::service]
trait FailureExample {
#[name = "doRun"]
async fn do_run() -> Result<(), TerminalError>;
}

struct FailureExampleImpl;

#[derive(Debug, thiserror::Error)]
#[error("I'm very bad, retry me")]
struct MyError;

impl FailureExample for FailureExampleImpl {
struct FailureExample;

#[restate_sdk::service]
impl FailureExample {
#[handler(name = "doRun")]
async fn do_run(&self, context: Context<'_>) -> Result<(), TerminalError> {
context
.run::<_, _, ()>(|| async move {
Expand All @@ -32,7 +28,7 @@ impl FailureExample for FailureExampleImpl {
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
HttpServer::new(Endpoint::builder().bind(FailureExampleImpl.serve()).build())
HttpServer::new(Endpoint::builder().bind(FailureExample.serve()).build())
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
.await;
}
15 changes: 6 additions & 9 deletions examples/greeter.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
use restate_sdk::prelude::*;
use std::convert::Infallible;

#[restate_sdk::service]
trait Greeter {
async fn greet(name: String) -> Result<String, Infallible>;
}
struct Greeter;

struct GreeterImpl;

impl Greeter for GreeterImpl {
async fn greet(&self, _: Context<'_>, name: String) -> Result<String, Infallible> {
#[restate_sdk::service]
impl Greeter {
#[handler]
async fn greet(&self, _ctx: Context<'_>, name: String) -> Result<String, Infallible> {
Ok(format!("Greetings {name}"))
}
}

#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
HttpServer::new(Endpoint::builder().bind(GreeterImpl.serve()).build())
HttpServer::new(Endpoint::builder().bind(Greeter.serve()).build())
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
.await;
}
13 changes: 5 additions & 8 deletions examples/run.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
use restate_sdk::prelude::*;
use std::collections::HashMap;

#[restate_sdk::service]
trait RunExample {
async fn do_run() -> Result<Json<HashMap<String, String>>, HandlerError>;
}
struct RunExample(reqwest::Client);

struct RunExampleImpl(reqwest::Client);

impl RunExample for RunExampleImpl {
#[restate_sdk::service]
impl RunExample {
#[handler]
async fn do_run(
&self,
context: Context<'_>,
Expand Down Expand Up @@ -39,7 +36,7 @@ async fn main() {
tracing_subscriber::fmt::init();
HttpServer::new(
Endpoint::builder()
.bind(RunExampleImpl(reqwest::Client::new()).serve())
.bind(RunExample(reqwest::Client::new()).serve())
.build(),
)
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
Expand Down
13 changes: 5 additions & 8 deletions examples/services/my_service.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use restate_sdk::prelude::*;

#[restate_sdk::service]
pub trait MyService {
async fn my_handler(greeting: String) -> Result<String, HandlerError>;
}

pub struct MyServiceImpl;
pub struct MyService;

impl MyService for MyServiceImpl {
#[restate_sdk::service(vis = "pub(crate)")]
impl MyService {
#[handler]
async fn my_handler(&self, _ctx: Context<'_>, greeting: String) -> Result<String, HandlerError> {
Ok(format!("{greeting}!"))
}
Expand All @@ -16,7 +13,7 @@ impl MyService for MyServiceImpl {
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
HttpServer::new(Endpoint::builder().bind(MyServiceImpl.serve()).build())
HttpServer::new(Endpoint::builder().bind(MyService.serve()).build())
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
.await;
}
17 changes: 7 additions & 10 deletions examples/services/my_virtual_object.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
use restate_sdk::prelude::*;

#[restate_sdk::object]
pub trait MyVirtualObject {
async fn my_handler(name: String) -> Result<String, HandlerError>;
#[shared]
async fn my_concurrent_handler(name: String) -> Result<String, HandlerError>;
}

pub struct MyVirtualObjectImpl;
pub struct MyVirtualObject;

impl MyVirtualObject for MyVirtualObjectImpl {
#[restate_sdk::object(vis = "pub(crate)")]
impl MyVirtualObject {
#[handler]
async fn my_handler(
&self,
ctx: ObjectContext<'_>,
greeting: String,
) -> Result<String, HandlerError> {
Ok(format!("Greetings {} {}", greeting, ctx.key()))
}

#[handler(shared)]
async fn my_concurrent_handler(
&self,
ctx: SharedObjectContext<'_>,
Expand All @@ -31,7 +28,7 @@ async fn main() {
tracing_subscriber::fmt::init();
HttpServer::new(
Endpoint::builder()
.bind(MyVirtualObjectImpl.serve())
.bind(MyVirtualObject.serve())
.build(),
)
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
Expand Down
17 changes: 7 additions & 10 deletions examples/services/my_workflow.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
use restate_sdk::prelude::*;

#[restate_sdk::workflow]
pub trait MyWorkflow {
async fn run(req: String) -> Result<String, HandlerError>;
#[shared]
async fn interact_with_workflow() -> Result<(), HandlerError>;
}

pub struct MyWorkflowImpl;
pub struct MyWorkflow;

impl MyWorkflow for MyWorkflowImpl {
#[restate_sdk::workflow(vis = "pub(crate)")]
impl MyWorkflow {
#[handler]
async fn run(&self, _ctx: WorkflowContext<'_>, _req: String) -> Result<String, HandlerError> {
// implement workflow logic here

Ok(String::from("success"))
}

#[handler(shared)]
async fn interact_with_workflow(
&self,
_ctx: SharedWorkflowContext<'_>,
Expand All @@ -29,7 +26,7 @@ impl MyWorkflow for MyWorkflowImpl {
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
HttpServer::new(Endpoint::builder().bind(MyWorkflowImpl.serve()).build())
HttpServer::new(Endpoint::builder().bind(MyWorkflow.serve()).build())
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
.await;
}
13 changes: 5 additions & 8 deletions examples/tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@ use std::time::Duration;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer};

#[restate_sdk::service]
trait Greeter {
async fn greet(name: String) -> Result<String, HandlerError>;
}
struct Greeter;

struct GreeterImpl;

impl Greeter for GreeterImpl {
#[restate_sdk::service]
impl Greeter {
#[handler]
async fn greet(&self, ctx: Context<'_>, name: String) -> Result<String, HandlerError> {
info!("Before sleep");
ctx.sleep(Duration::from_secs(61)).await?; // More than suspension timeout to trigger replay
Expand All @@ -31,7 +28,7 @@ async fn main() {
.with_filter(replay_filter),
)
.init();
HttpServer::new(Endpoint::builder().bind(GreeterImpl.serve()).build())
HttpServer::new(Endpoint::builder().bind(Greeter.serve()).build())
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
.await;
}
Loading