Skip to content

Commit 99c8a79

Browse files
[fix] small optimizations (#60)
📝 Description Fix TOCTOU race in Registry::spawn_and_register Replace 100ms polling in wait_task_removed with event-driven approach Sort AliveTracker::snapshot() output for API consistency
1 parent cdf6f20 commit 99c8a79

File tree

3 files changed

+42
-63
lines changed

3 files changed

+42
-63
lines changed

src/core/alive.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,12 @@ impl AliveTracker {
115115
/// during graceful shutdown (tasks that didn't stop within grace period).
116116
pub async fn snapshot(&self) -> Vec<String> {
117117
let state = self.state.read().await;
118-
let alive: Vec<String> = state
118+
let mut alive: Vec<String> = state
119119
.iter()
120120
.filter(|(_, ts)| ts.alive)
121121
.map(|(name, _)| name.clone())
122122
.collect();
123+
alive.sort_unstable();
123124
alive
124125
}
125126

src/core/registry.rs

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,15 @@ impl Registry {
177177
/// Spawns an actor and registers its handle.
178178
async fn spawn_and_register(&self, spec: TaskSpec) {
179179
let task_name = spec.task().name().to_string();
180-
{
181-
let tasks = self.tasks.read().await;
182-
if tasks.contains_key(&task_name) {
183-
self.bus.publish(
184-
Event::new(EventKind::TaskFailed)
185-
.with_task(task_name)
186-
.with_reason("task_already_exists"),
187-
);
188-
return;
189-
}
180+
181+
let mut tasks = self.tasks.write().await;
182+
if tasks.contains_key(&task_name) {
183+
self.bus.publish(
184+
Event::new(EventKind::TaskFailed)
185+
.with_task(task_name)
186+
.with_reason("task_already_exists"),
187+
);
188+
return;
190189
}
191190

192191
let task_token = self.runtime_token.child_token();
@@ -210,23 +209,14 @@ impl Registry {
210209
cancel: task_token,
211210
};
212211

213-
let mut tasks = self.tasks.write().await;
214212
let was_empty = tasks.is_empty();
215-
let inserted = tasks.insert(task_name.clone(), handle).is_none();
213+
tasks.insert(task_name.clone(), handle);
216214
let len_after = tasks.len();
217215
drop(tasks);
218216

219-
if inserted {
220-
self.notify_after_insert(was_empty, len_after);
221-
self.bus
222-
.publish(Event::new(EventKind::TaskAdded).with_task(task_name));
223-
} else {
224-
self.bus.publish(
225-
Event::new(EventKind::TaskFailed)
226-
.with_task(task_name)
227-
.with_reason("task_already_exists_race"),
228-
);
229-
}
217+
self.notify_after_insert(was_empty, len_after);
218+
self.bus
219+
.publish(Event::new(EventKind::TaskAdded).with_task(task_name));
230220
}
231221

232222
/// Removes a task and cancels its token.

src/core/supervisor.rs

Lines changed: 27 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -390,50 +390,38 @@ impl Supervisor {
390390
wait_for: Duration,
391391
) -> Result<bool, RuntimeError> {
392392
let target = name.to_string();
393-
let start = tokio::time::Instant::now();
394-
let mut last_poll = tokio::time::Instant::now();
395-
let poll_interval = Duration::from_millis(100);
396-
397-
loop {
398-
if start.elapsed() >= wait_for {
399-
return Err(RuntimeError::TaskRemoveTimeout {
400-
name: target,
401-
timeout: wait_for,
402-
});
403-
}
404-
if last_poll.elapsed() >= poll_interval {
405-
let tasks = self.registry.list().await;
406-
if !tasks.contains(&target) {
407-
return Ok(true);
408-
}
409-
last_poll = tokio::time::Instant::now();
410-
}
411-
412-
let recv_timeout = poll_interval
413-
.checked_sub(last_poll.elapsed())
414-
.unwrap_or(Duration::from_millis(10));
415393

416-
match tokio::time::timeout(recv_timeout, rx.recv()).await {
417-
Ok(Ok(ev))
418-
if matches!(ev.kind, EventKind::TaskRemoved)
419-
&& ev.task.as_deref() == Some(&target) =>
420-
{
421-
return Ok(true);
422-
}
423-
Ok(Ok(_)) => {}
424-
Ok(Err(broadcast::error::RecvError::Closed)) => {
425-
let tasks = self.registry.list().await;
426-
return Ok(!tasks.contains(&target));
427-
}
428-
Ok(Err(broadcast::error::RecvError::Lagged(_))) => {
429-
let tasks = self.registry.list().await;
430-
if !tasks.contains(&target) {
394+
let wait_for_event = async {
395+
loop {
396+
match rx.recv().await {
397+
Ok(ev)
398+
if matches!(ev.kind, EventKind::TaskRemoved)
399+
&& ev.task.as_deref() == Some(target.as_str()) =>
400+
{
431401
return Ok(true);
432402
}
433-
last_poll = tokio::time::Instant::now();
403+
Ok(_) => {}
404+
Err(broadcast::error::RecvError::Lagged(_)) => {
405+
// We may have missed the TaskRemoved event; check the registry.
406+
let tasks = self.registry.list().await;
407+
if !tasks.contains(&target) {
408+
return Ok(true);
409+
}
410+
}
411+
Err(broadcast::error::RecvError::Closed) => {
412+
let tasks = self.registry.list().await;
413+
return Ok(!tasks.contains(&target));
414+
}
434415
}
435-
Err(_elapsed) => {}
436416
}
417+
};
418+
419+
match timeout(wait_for, wait_for_event).await {
420+
Ok(result) => result,
421+
Err(_) => Err(RuntimeError::TaskRemoveTimeout {
422+
name: target,
423+
timeout: wait_for,
424+
}),
437425
}
438426
}
439427
}

0 commit comments

Comments
 (0)