Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 69 additions & 8 deletions crates/taskito-core/src/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ pub use crate::job::Job;
pub struct SchedulerConfig {
/// Interval between scheduler poll cycles.
pub poll_interval: Duration,
/// Priority boost per second of wait time. `None` disables aging.
/// Example: `Some(1)` boosts priority by 1 per second of wait.
pub aging_factor: Option<i64>,
/// Reap stale jobs every N iterations.
pub reap_interval: u32,
/// Check periodic tasks every N iterations.
Expand All @@ -36,6 +39,7 @@ impl Default for SchedulerConfig {
fn default() -> Self {
Self {
poll_interval: Duration::from_millis(50),
aging_factor: None,
reap_interval: 100,
periodic_check_interval: 60,
cleanup_interval: 1200,
Expand Down Expand Up @@ -116,6 +120,34 @@ pub struct QueueConfig {
pub max_concurrent: Option<i32>,
}

/// In-memory cache of average task execution duration.
/// Updated on every `handle_result()` — no DB queries needed.
struct TaskDurationCache {
    /// Per-task accumulator: task name -> (sample count, total wall time in ns).
    durations: HashMap<String, (u64, i64)>,
}

impl TaskDurationCache {
    /// Create an empty cache with no recorded samples.
    fn new() -> Self {
        Self {
            durations: HashMap::new(),
        }
    }

    /// Fold one completed execution into the running totals for `task_name`.
    /// The nanosecond sum saturates instead of wrapping on overflow.
    fn record(&mut self, task_name: &str, wall_time_ns: i64) {
        let (count, sum_ns) = self.durations.entry(task_name.to_string()).or_default();
        *count += 1;
        *sum_ns = sum_ns.saturating_add(wall_time_ns);
    }

    /// Mean wall time in nanoseconds for `task_name`, or `None` when the
    /// task has never been recorded.
    #[allow(dead_code)]
    fn avg_ns(&self, task_name: &str) -> Option<i64> {
        match self.durations.get(task_name) {
            Some(&(count, sum)) if count > 0 => Some(sum / count as i64),
            _ => None,
        }
    }
}

/// The central scheduler that coordinates job dispatch, retries, rate limiting, and circuit breakers.
pub struct Scheduler {
storage: StorageBackend,
Expand All @@ -129,6 +161,7 @@ pub struct Scheduler {
shutdown: Arc<Notify>,
paused_cache: Mutex<(HashSet<String>, Instant)>,
namespace: Option<String>,
duration_cache: Mutex<TaskDurationCache>,
}

/// Counters for tick-based scheduling of periodic maintenance tasks.
Expand Down Expand Up @@ -162,6 +195,7 @@ impl Scheduler {
shutdown: Arc::new(Notify::new()),
paused_cache: Mutex::new((HashSet::new(), Instant::now())),
namespace,
duration_cache: Mutex::new(TaskDurationCache::new()),
}
}

Expand Down Expand Up @@ -192,24 +226,44 @@ impl Scheduler {

/// Run the scheduler loop. Polls for ready jobs and dispatches them
/// to the worker pool via the provided channel.
///
/// Uses adaptive polling: starts at `poll_interval`, backs off
/// exponentially when no work is found (capped at 200ms, or at
/// `poll_interval` itself when that is larger), and resets to
/// `poll_interval` as soon as a tick does work.
///
/// Returns when `shutdown` is notified.
pub async fn run(&self, job_tx: tokio::sync::mpsc::Sender<Job>) {
    let mut counters = TickCounters::default();
    let base_interval = self.config.poll_interval;
    // Backoff must never drop *below* the configured interval: if the
    // operator set poll_interval above the 200ms cap, honor their value
    // as the ceiling instead of speeding up while idle.
    let max_interval = Duration::from_millis(200).max(base_interval);
    let mut current_interval = base_interval;

    loop {
        tokio::select! {
            _ = self.shutdown.notified() => break,
            _ = tokio::time::sleep(current_interval) => {}
        }

        let had_work = self.tick(&job_tx, &mut counters);
        if had_work {
            // Work was found — poll again promptly.
            current_interval = base_interval;
        } else {
            // Idle — double the wait (capped) to cut storage polling load.
            current_interval = (current_interval * 2).min(max_interval);
        }
    }
}

/// Execute one iteration of the scheduler loop.
fn tick(&self, job_tx: &tokio::sync::mpsc::Sender<Job>, counters: &mut TickCounters) {
if let Err(e) = self.try_dispatch(job_tx) {
error!("scheduler error: {e}");
}
/// Returns true if any work was done (job dispatched or periodic task enqueued),
/// which resets the adaptive poll interval.
fn tick(&self, job_tx: &tokio::sync::mpsc::Sender<Job>, counters: &mut TickCounters) -> bool {
let dispatched = match self.try_dispatch(job_tx) {
Ok(d) => d,
Err(e) => {
error!("scheduler error: {e}");
false
}
};

let mut had_maintenance = false;

counters.reap += 1;
counters.periodic += 1;
Expand All @@ -225,8 +279,13 @@ impl Scheduler {
.periodic
.is_multiple_of(self.config.periodic_check_interval)
{
if let Err(e) = self.check_periodic() {
error!("periodic check error: {e}");
match self.check_periodic() {
Ok(()) => {
// Periodic tasks may have been enqueued — reset polling
// so the next tick picks them up quickly.
had_maintenance = true;
}
Err(e) => error!("periodic check error: {e}"),
}
}

Expand All @@ -238,6 +297,8 @@ impl Scheduler {
error!("auto-cleanup error: {e}");
}
}

dispatched || had_maintenance
}
}

Expand Down
10 changes: 10 additions & 0 deletions crates/taskito-core/src/scheduler/result_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ impl Scheduler {
error!("circuit breaker error for {task_name}: {e}");
}

// Update in-memory duration cache for smart scheduling
if let Ok(mut cache) = self.duration_cache.lock() {
cache.record(task_name, wall_time_ns);
}

Ok(ResultOutcome::Success {
job_id,
task_name: task_name.clone(),
Expand Down Expand Up @@ -71,6 +76,11 @@ impl Scheduler {
log::error!("circuit breaker error for {task_name}: {e}");
}

// Update in-memory duration cache for smart scheduling
if let Ok(mut cache) = self.duration_cache.lock() {
cache.record(&task_name, wall_time_ns);
}

// Look up the job to get the queue name for middleware context
let queue = self
.storage
Expand Down
33 changes: 33 additions & 0 deletions crates/taskito-python/src/prefork/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@ pub fn least_loaded(in_flight_counts: &[u32]) -> usize {
.unwrap_or(0)
}

/// Selects the child with the lowest weighted load.
///
/// Score = in_flight_count * `avg_duration_ns`, saturating on overflow so a
/// huge count or average cannot wrap negative and invert the ordering.
///
/// NOTE(review): because the *same* `avg_duration_ns` scalar multiplies every
/// worker's count, the resulting ordering is identical to `least_loaded`;
/// genuine duration-aware balancing would need a per-worker average. The
/// shape is kept so call sites can migrate to per-worker durations without
/// an interface break.
///
/// Falls back to `least_loaded` when `avg_duration_ns` is zero or negative.
/// Returns 0 when `in_flight_counts` is empty.
#[allow(dead_code)]
pub fn weighted_least_loaded(in_flight_counts: &[u32], avg_duration_ns: i64) -> usize {
    if avg_duration_ns <= 0 {
        return least_loaded(in_flight_counts);
    }
    in_flight_counts
        .iter()
        .enumerate()
        // saturating_mul: u32::MAX * a large average overflows i64; saturation
        // keeps oversized scores comparable instead of wrapping.
        .min_by_key(|(_, &count)| i64::from(count).saturating_mul(avg_duration_ns))
        .map(|(idx, _)| idx)
        .unwrap_or(0)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -29,4 +49,17 @@ mod tests {
fn test_least_loaded_single() {
assert_eq!(least_loaded(&[5]), 0);
}

#[test]
fn test_weighted_picks_lowest_score() {
    // With a uniform average duration the lowest score is simply the
    // lowest in-flight count — worker 1 wins here.
    // Worker 0: 2 in-flight * 100ns = 200
    // Worker 1: 1 in-flight * 100ns = 100 ← pick this
    // Worker 2: 3 in-flight * 100ns = 300
    assert_eq!(weighted_least_loaded(&[2, 1, 3], 100), 1);
}

#[test]
fn test_weighted_falls_back_on_zero_duration() {
    // A non-positive average disables weighting entirely, so the result
    // must match plain least_loaded (index of the minimum count).
    assert_eq!(weighted_least_loaded(&[3, 0, 2], 0), 1); // same as least_loaded
}
}
Loading