diff --git a/src/core/alive.rs b/src/core/alive.rs
index 5414f0c..1c1b0a1 100644
--- a/src/core/alive.rs
+++ b/src/core/alive.rs
@@ -115,11 +115,12 @@ impl AliveTracker {
     /// during graceful shutdown (tasks that didn't stop within grace period).
     pub async fn snapshot(&self) -> Vec {
         let state = self.state.read().await;
-        let alive: Vec = state
+        let mut alive: Vec = state // mutable so the names can be sorted in place below
             .iter()
             .filter(|(_, ts)| ts.alive)
             .map(|(name, _)| name.clone())
             .collect();
+        alive.sort_unstable(); // order the collected names before returning them
         alive
     }
 
diff --git a/src/core/registry.rs b/src/core/registry.rs
index 5e34b09..3c4b048 100644
--- a/src/core/registry.rs
+++ b/src/core/registry.rs
@@ -177,16 +177,15 @@ impl Registry {
     /// Spawns an actor and registers its handle.
     async fn spawn_and_register(&self, spec: TaskSpec) {
         let task_name = spec.task().name().to_string();
-        {
-            let tasks = self.tasks.read().await;
-            if tasks.contains_key(&task_name) {
-                self.bus.publish(
-                    Event::new(EventKind::TaskFailed)
-                        .with_task(task_name)
-                        .with_reason("task_already_exists"),
-                );
-                return;
-            }
+
+        let mut tasks = self.tasks.write().await; // write guard taken before the duplicate check; the insert further down reuses it
+        if tasks.contains_key(&task_name) { // duplicate name: publish TaskFailed and bail out
+            self.bus.publish(
+                Event::new(EventKind::TaskFailed)
+                    .with_task(task_name)
+                    .with_reason("task_already_exists"),
+            );
+            return;
         }
 
         let task_token = self.runtime_token.child_token();
@@ -210,23 +209,14 @@ impl Registry {
             cancel: task_token,
         };
 
-        let mut tasks = self.tasks.write().await;
         let was_empty = tasks.is_empty();
-        let inserted = tasks.insert(task_name.clone(), handle).is_none();
+        tasks.insert(task_name.clone(), handle); // previous value ignored: contains_key was checked under this same guard (lines between the hunks are not shown here)
         let len_after = tasks.len();
         drop(tasks);
 
-        if inserted {
-            self.notify_after_insert(was_empty, len_after);
-            self.bus
-                .publish(Event::new(EventKind::TaskAdded).with_task(task_name));
-        } else {
-            self.bus.publish(
-                Event::new(EventKind::TaskFailed)
-                    .with_task(task_name)
-                    .with_reason("task_already_exists_race"),
-            );
-        }
+        self.notify_after_insert(was_empty, len_after); // runs after drop(tasks) above, i.e. without the guard held
+        self.bus
+            .publish(Event::new(EventKind::TaskAdded).with_task(task_name)); // now unconditional: the "task_already_exists_race" branch is removed
     }
 
     /// Removes a task and cancels its token.
diff --git a/src/core/supervisor.rs b/src/core/supervisor.rs
index 49e1b5e..5b77975 100644
--- a/src/core/supervisor.rs
+++ b/src/core/supervisor.rs
@@ -390,50 +390,38 @@ impl Supervisor {
         wait_for: Duration,
     ) -> Result {
         let target = name.to_string();
-        let start = tokio::time::Instant::now();
-        let mut last_poll = tokio::time::Instant::now();
-        let poll_interval = Duration::from_millis(100);
-
-        loop {
-            if start.elapsed() >= wait_for {
-                return Err(RuntimeError::TaskRemoveTimeout {
-                    name: target,
-                    timeout: wait_for,
-                });
-            }
 
-            if last_poll.elapsed() >= poll_interval {
-                let tasks = self.registry.list().await;
-                if !tasks.contains(&target) {
-                    return Ok(true);
-                }
-                last_poll = tokio::time::Instant::now();
-            }
-
-            let recv_timeout = poll_interval
-                .checked_sub(last_poll.elapsed())
-                .unwrap_or(Duration::from_millis(10));
-            match tokio::time::timeout(recv_timeout, rx.recv()).await {
-                Ok(Ok(ev))
-                    if matches!(ev.kind, EventKind::TaskRemoved)
-                        && ev.task.as_deref() == Some(&target) =>
-                {
-                    return Ok(true);
-                }
-                Ok(Ok(_)) => {}
-                Ok(Err(broadcast::error::RecvError::Closed)) => {
-                    let tasks = self.registry.list().await;
-                    return Ok(!tasks.contains(&target));
-                }
-                Ok(Err(broadcast::error::RecvError::Lagged(_))) => {
-                    let tasks = self.registry.list().await;
-                    if !tasks.contains(&target) {
+        let wait_for_event = async { // future that finishes once removal of `target` is observed; awaited under the timeout below
+            loop { // NOTE(review): the old periodic registry poll is gone; the registry is now consulted only on Lagged/Closed
+                match rx.recv().await { // no per-iteration timeout; the outer timeout bounds the whole wait
+                    Ok(ev)
+                        if matches!(ev.kind, EventKind::TaskRemoved)
+                            && ev.task.as_deref() == Some(target.as_str()) =>
+                    {
                         return Ok(true);
                     }
-                    last_poll = tokio::time::Instant::now();
+                    Ok(_) => {} // any other event: keep receiving
+                    Err(broadcast::error::RecvError::Lagged(_)) => {
+                        // We may have missed the TaskRemoved event; check the registry.
+                        let tasks = self.registry.list().await;
+                        if !tasks.contains(&target) {
+                            return Ok(true);
+                        }
+                    }
+                    Err(broadcast::error::RecvError::Closed) => { // receiver reports Closed: the result is whatever the registry says now
+                        let tasks = self.registry.list().await;
+                        return Ok(!tasks.contains(&target));
+                    }
                 }
-                Err(_elapsed) => {}
             }
+        };
+
+        match timeout(wait_for, wait_for_event).await { // NOTE(review): assumes `timeout` is tokio::time::timeout, imported outside this hunk — confirm
+            Ok(result) => result, // inner future finished first: pass its Result through
+            Err(_) => Err(RuntimeError::TaskRemoveTimeout { // `wait_for` elapsed before removal was observed
+                name: target,
+                timeout: wait_for,
+            }),
         }
     }
 }