diff --git a/src/lib.rs b/src/lib.rs index cd977a3..5295ab3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,9 @@ use droppable_future::*; mod task_identifier; pub use task_identifier::*; +mod split_ticked_async_executor; +pub use split_ticked_async_executor::*; + mod ticked_async_executor; pub use ticked_async_executor::*; diff --git a/src/split_ticked_async_executor.rs b/src/split_ticked_async_executor.rs new file mode 100644 index 0000000..002b3c7 --- /dev/null +++ b/src/split_ticked_async_executor.rs @@ -0,0 +1,165 @@ +use std::{ + future::Future, + sync::{ + atomic::{AtomicUsize, Ordering}, + mpsc, Arc, + }, +}; + +use crate::{DroppableFuture, TaskIdentifier, TickedTimer}; + +#[derive(Debug)] +pub enum TaskState { + Spawn(TaskIdentifier), + Wake(TaskIdentifier), + Tick(TaskIdentifier, f64), + Drop(TaskIdentifier), +} + +pub type Task = async_task::Task; +type Payload = (TaskIdentifier, async_task::Runnable); + +pub fn new_split_ticked_async_executor( + observer: O, +) -> (TickedAsyncExecutorSpawner, TickedAsyncExecutorTicker) +where + O: Fn(TaskState) + Clone + Send + Sync + 'static, +{ + let (tx_channel, rx_channel) = mpsc::channel(); + let num_woken_tasks = Arc::new(AtomicUsize::new(0)); + let num_spawned_tasks = Arc::new(AtomicUsize::new(0)); + let (tx_tick_event, rx_tick_event) = tokio::sync::watch::channel(1.0); + let spawner = TickedAsyncExecutorSpawner { + tx_channel, + num_woken_tasks: num_woken_tasks.clone(), + num_spawned_tasks: num_spawned_tasks.clone(), + observer: observer.clone(), + rx_tick_event, + }; + let ticker = TickedAsyncExecutorTicker { + rx_channel, + num_woken_tasks, + num_spawned_tasks, + observer, + tx_tick_event, + }; + (spawner, ticker) +} + +pub struct TickedAsyncExecutorSpawner { + tx_channel: mpsc::Sender, + num_woken_tasks: Arc, + + num_spawned_tasks: Arc, + // TODO, Or we need a Single Producer - Multi Consumer channel i.e Broadcast channel + // Broadcast recv channel should be notified when there are new messages in the queue + // Broadcast channel must also be able to remove older/stale messages (like a RingBuffer) + observer: O, + rx_tick_event: tokio::sync::watch::Receiver, +} + +impl TickedAsyncExecutorSpawner +where + O: Fn(TaskState) + Clone + Send + Sync + 'static, +{ + pub fn spawn_local( + &self, + identifier: impl Into, + future: impl Future + 'static, + ) -> Task + where + T: 'static, + { + let identifier = identifier.into(); + let future = self.droppable_future(identifier.clone(), future); + let schedule = self.runnable_schedule_cb(identifier); + let (runnable, task) = async_task::spawn_local(future, schedule); + runnable.schedule(); + task + } + + pub fn create_timer(&self) -> TickedTimer { + let tick_recv = self.rx_tick_event.clone(); + TickedTimer { tick_recv } + } + + pub fn tick_channel(&self) -> tokio::sync::watch::Receiver { + self.rx_tick_event.clone() + } + + pub fn num_tasks(&self) -> usize { + self.num_spawned_tasks.load(Ordering::Relaxed) + } + + fn droppable_future( + &self, + identifier: TaskIdentifier, + future: F, + ) -> DroppableFuture + where + F: Future, + { + let observer = self.observer.clone(); + + // Spawn Task + self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); + observer(TaskState::Spawn(identifier.clone())); + + // Droppable Future registering on_drop callback + let num_spawned_tasks = self.num_spawned_tasks.clone(); + DroppableFuture::new(future, move || { + num_spawned_tasks.fetch_sub(1, Ordering::Relaxed); + observer(TaskState::Drop(identifier.clone())); + }) + } + + fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) { + let sender = self.tx_channel.clone(); + let num_woken_tasks = self.num_woken_tasks.clone(); + let observer = self.observer.clone(); + move |runnable| { + sender.send((identifier.clone(), runnable)).unwrap_or(()); + num_woken_tasks.fetch_add(1, Ordering::Relaxed); + observer(TaskState::Wake(identifier.clone())); + } + } +} + +pub struct TickedAsyncExecutorTicker { + rx_channel: mpsc::Receiver, + num_woken_tasks: Arc, + num_spawned_tasks: Arc, + observer: O, + tx_tick_event: tokio::sync::watch::Sender, +} + +impl TickedAsyncExecutorTicker +where + O: Fn(TaskState), +{ + pub fn tick(&self, delta: f64, limit: Option) { + let _r = self.tx_tick_event.send(delta); + + let mut num_woken_tasks = self.num_woken_tasks.load(Ordering::Relaxed); + if let Some(limit) = limit { + // Woken tasks should not exceed the allowed limit + num_woken_tasks = num_woken_tasks.min(limit); + } + + self.rx_channel + .try_iter() + .take(num_woken_tasks) + .for_each(|(identifier, runnable)| { + (self.observer)(TaskState::Tick(identifier, delta)); + runnable.run(); + }); + self.num_woken_tasks + .fetch_sub(num_woken_tasks, Ordering::Relaxed); + } + + pub fn wait_till_completed(&self, constant_delta: f64) { + while self.num_spawned_tasks.load(Ordering::Relaxed) != 0 { + self.tick(constant_delta, None); + } + } +} diff --git a/src/ticked_async_executor.rs b/src/ticked_async_executor.rs index 9af2875..525ab10 100644 --- a/src/ticked_async_executor.rs +++ b/src/ticked_async_executor.rs @@ -1,36 +1,13 @@ -use std::{ - future::Future, - sync::{ - atomic::{AtomicUsize, Ordering}, - mpsc, Arc, - }, -}; - -use crate::{DroppableFuture, TaskIdentifier, TickedTimer}; - -#[derive(Debug)] -pub enum TaskState { - Spawn(TaskIdentifier), - Wake(TaskIdentifier), - Tick(TaskIdentifier, f64), - Drop(TaskIdentifier), -} +use std::future::Future; -pub type Task = async_task::Task; -type Payload = (TaskIdentifier, async_task::Runnable); +use crate::{ + new_split_ticked_async_executor, Task, TaskIdentifier, TaskState, TickedAsyncExecutorSpawner, + TickedAsyncExecutorTicker, TickedTimer, +}; pub struct TickedAsyncExecutor { - channel: (mpsc::Sender, mpsc::Receiver), - num_woken_tasks: Arc, - - num_spawned_tasks: Arc, - - // TODO, Or we need a Single Producer - Multi Consumer channel i.e Broadcast channel - // Broadcast recv channel should be notified when there are new messages in the queue - // Broadcast channel must also be able to remove older/stale messages (like a RingBuffer) - observer: O, - - tick_event: tokio::sync::watch::Sender, + spawner: TickedAsyncExecutorSpawner, + ticker: TickedAsyncExecutorTicker, } impl Default for TickedAsyncExecutor { @@ -44,13 +21,8 @@ where O: Fn(TaskState) + Clone + Send + Sync + 'static, { pub fn new(observer: O) -> Self { - Self { - channel: mpsc::channel(), - num_woken_tasks: Arc::new(AtomicUsize::new(0)), - num_spawned_tasks: Arc::new(AtomicUsize::new(0)), - observer, - tick_event: tokio::sync::watch::channel(1.0).0, - } + let (spawner, ticker) = new_split_ticked_async_executor(observer); + Self { spawner, ticker } } pub fn spawn_local( @@ -61,16 +33,11 @@ where where T: 'static, { - let identifier = identifier.into(); - let future = self.droppable_future(identifier.clone(), future); - let schedule = self.runnable_schedule_cb(identifier); - let (runnable, task) = async_task::spawn_local(future, schedule); - runnable.schedule(); - task + self.spawner.spawn_local(identifier, future) } pub fn num_tasks(&self) -> usize { - self.num_spawned_tasks.load(Ordering::Relaxed) + self.spawner.num_tasks() } /// Run the woken tasks once @@ -81,72 +48,25 @@ where /// `limit` is used to limit the number of woken tasks run per tick /// - None would imply that there is no limit (all woken tasks would run) /// - Some(limit) would imply that [0..limit] woken tasks would run, - /// even if more tasks are woken. + /// even if more tasks are woken. /// /// Tick is !Sync i.e cannot be invoked from multiple threads /// /// NOTE: Will not run tasks that are woken/scheduled immediately after `Runnable::run` pub fn tick(&self, delta: f64, limit: Option) { - let _r = self.tick_event.send(delta); - - let mut num_woken_tasks = self.num_woken_tasks.load(Ordering::Relaxed); - if let Some(limit) = limit { - // Woken tasks should not exceed the allowed limit - num_woken_tasks = num_woken_tasks.min(limit); - } - - self.channel - .1 - .try_iter() - .take(num_woken_tasks) - .for_each(|(identifier, runnable)| { - (self.observer)(TaskState::Tick(identifier, delta)); - runnable.run(); - }); - self.num_woken_tasks - .fetch_sub(num_woken_tasks, Ordering::Relaxed); + self.ticker.tick(delta, limit); } pub fn create_timer(&self) -> TickedTimer { - let tick_recv = self.tick_event.subscribe(); - TickedTimer { tick_recv } + self.spawner.create_timer() } pub fn tick_channel(&self) -> tokio::sync::watch::Receiver { - self.tick_event.subscribe() - } - - fn droppable_future( - &self, - identifier: TaskIdentifier, - future: F, - ) -> DroppableFuture - where - F: Future, - { - let observer = self.observer.clone(); - - // Spawn Task - self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); - observer(TaskState::Spawn(identifier.clone())); - - // Droppable Future registering on_drop callback - let num_spawned_tasks = self.num_spawned_tasks.clone(); - DroppableFuture::new(future, move || { - num_spawned_tasks.fetch_sub(1, Ordering::Relaxed); - observer(TaskState::Drop(identifier.clone())); - }) + self.spawner.tick_channel() } - fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) { - let sender = self.channel.0.clone(); - let num_woken_tasks = self.num_woken_tasks.clone(); - let observer = self.observer.clone(); - move |runnable| { - sender.send((identifier.clone(), runnable)).unwrap_or(()); - num_woken_tasks.fetch_add(1, Ordering::Relaxed); - observer(TaskState::Wake(identifier.clone())); - } + pub fn wait_till_completed(&self, delta: f64) { + self.ticker.wait_till_completed(delta); } } @@ -220,9 +140,7 @@ mod tests { assert_eq!(executor.num_tasks(), 3); // Since we have cancelled the tasks above, the loops should eventually end - while executor.num_tasks() != 0 { - executor.tick(DELTA, None); - } + executor.wait_till_completed(DELTA); } #[test] @@ -311,8 +229,8 @@ mod tests { } for i in 0..10 { - let woken_tasks = executor.num_woken_tasks.load(Ordering::Relaxed); - assert_eq!(woken_tasks, 10 - i); + let num_tasks = executor.num_tasks(); + assert_eq!(num_tasks, 10 - i); executor.tick(0.1, Some(1)); }