diff --git a/src/julia_threads.h b/src/julia_threads.h
index 061eb9266e7a7..0805441a9586f 100644
--- a/src/julia_threads.h
+++ b/src/julia_threads.h
@@ -167,6 +167,9 @@ typedef struct _jl_tls_states_t {
     struct _jl_task_t *next_task;
     struct _jl_task_t *previous_task;
     struct _jl_task_t *root_task;
+    // The scheduler sets this to instruct the task switching code (jl_switch())
+    // to wait in the scheduler for another task.
+    int8_t wait_in_scheduler;
     struct _jl_timing_block_t *timing_stack;
     // This is the location of our copy_stack
     void *stackbase;
diff --git a/src/scheduler.c b/src/scheduler.c
index 731a0c5146605..d7c0804a64e1a 100644
--- a/src/scheduler.c
+++ b/src/scheduler.c
@@ -378,7 +378,18 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q,
         jl_cpu_pause();
         jl_ptls_t ptls = ct->ptls;
-        if (sleep_check_after_threshold(&start_cycles) || (ptls->tid == jl_atomic_load_relaxed(&io_loop_tid) && (!jl_atomic_load_relaxed(&_threadedregion) || wait_empty))) {
+        if (sleep_check_after_threshold(&start_cycles) ||
+            (ptls->tid == jl_atomic_load_relaxed(&io_loop_tid) &&
+             (!jl_atomic_load_relaxed(&_threadedregion) || wait_empty))) {
+            // The queues seem empty and this thread is on its way to sleeping;
+            // switch to the root task to do that and instruct the task switching
+            // code to wait for another task.
+            if (ct != ptls->root_task) {
+                ptls->wait_in_scheduler = 1;
+                jl_switchto(&ptls->root_task);
+                return ct;
+            }
+
             // acquire sleep-check lock
             assert(jl_atomic_load_relaxed(&ptls->sleep_check_state) == not_sleeping);
             jl_atomic_store_relaxed(&ptls->sleep_check_state, sleeping);
diff --git a/src/task.c b/src/task.c
index 068689d534a03..992b3ac18614d 100644
--- a/src/task.c
+++ b/src/task.c
@@ -427,6 +427,12 @@ JL_DLLEXPORT jl_task_t *jl_get_next_task(void) JL_NOTSAFEPOINT
 const char tsan_state_corruption[] = "TSAN state corrupted. Exiting HARD!\n";
 #endif
 
+// is `t` the root task?
+static int is_root_task(jl_task_t *ct, jl_task_t *t)
+{
+    return ct->ptls->root_task == t;
+}
+
 JL_NO_ASAN static void ctx_switch(jl_task_t *lastt)
 {
     jl_ptls_t ptls = lastt->ptls;
@@ -446,7 +452,8 @@ JL_NO_ASAN static void ctx_switch(jl_task_t *lastt)
     }
 #endif
 
-    int killed = jl_atomic_load_relaxed(&lastt->_state) != JL_TASK_STATE_RUNNABLE;
+    int killed = (jl_atomic_load_relaxed(&lastt->_state) != JL_TASK_STATE_RUNNABLE) &&
+                 !is_root_task(t, lastt);
     if (!t->ctx.started && !t->ctx.copy_stack) {
         // may need to allocate the stack
         if (t->ctx.stkbuf == NULL) {
@@ -663,6 +670,7 @@ JL_DLLEXPORT void jl_switch(void) JL_NOTSAFEPOINT_LEAVE JL_NOTSAFEPOINT_ENTER
 {
     jl_task_t *ct = jl_current_task;
     jl_ptls_t ptls = ct->ptls;
+switch_restart:
     jl_task_t *t = ptls->next_task;
     if (t == ct) {
         return;
@@ -711,6 +719,12 @@ JL_DLLEXPORT void jl_switch(void) JL_NOTSAFEPOINT_LEAVE JL_NOTSAFEPOINT_ENTER
     if (other_defer_signal && !defer_signal)
         jl_sigint_safepoint(ptls);
 
+    if (ptls->wait_in_scheduler) {
+        ptls->wait_in_scheduler = 0;
+        jl_set_next_task(jl_task_get_next());
+        goto switch_restart;
+    }
+
     JL_PROBE_RT_RUN_TASK(ct);
     jl_gc_unsafe_leave(ptls, gc_state);
 }
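
For orientation, here is a minimal, self-contained sketch of the control flow this diff introduces: the scheduler sets wait_in_scheduler and switches to the root task instead of sleeping on the current task's stack, and the switch code notices the flag, waits for another runnable task, and restarts the switch with it. All names below (toy_tls_t, toy_scheduler_sleep, toy_switch, toy_wait_for_task) are placeholders for illustration only, not the runtime's actual API.

/* Toy model of the wait_in_scheduler handshake; placeholder names, not
 * the real runtime. */
#include <stdio.h>

typedef struct {
    int wait_in_scheduler;  /* scheduler -> switch code: "wait for another task" */
    int root_task;          /* stand-in for ptls->root_task  */
    int current_task;       /* stand-in for jl_current_task  */
    int next_task;          /* stand-in for ptls->next_task  */
} toy_tls_t;

/* Stand-in for the blocking part of jl_task_get_next(): pretend a
 * runnable task eventually shows up. */
static int toy_wait_for_task(void)
{
    static int next_id = 100;
    return next_id++;
}

/* Stand-in for jl_switch(): if the scheduler asked us to wait in the
 * scheduler, fetch another task and restart the switch. */
static void toy_switch(toy_tls_t *tls)
{
switch_restart:;
    int t = tls->next_task;
    if (t == tls->current_task)
        return;
    /* ... the real context switch would happen here ... */
    tls->current_task = t;

    if (tls->wait_in_scheduler) {
        tls->wait_in_scheduler = 0;
        tls->next_task = toy_wait_for_task();
        goto switch_restart;
    }
    printf("running task %d\n", tls->current_task);
}

/* Stand-in for the sleep path in jl_task_get_next(): instead of
 * sleeping on the current task's stack, mark the flag and switch to
 * the root task, which does the waiting. */
static void toy_scheduler_sleep(toy_tls_t *tls)
{
    if (tls->current_task != tls->root_task) {
        tls->wait_in_scheduler = 1;
        tls->next_task = tls->root_task;
        toy_switch(tls);
    }
}

int main(void)
{
    toy_tls_t tls = { .root_task = 0, .current_task = 1, .next_task = 1 };
    toy_scheduler_sleep(&tls);   /* prints "running task 100" */
    return 0;
}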