Skip to content

Commit 39fc4ee

Browse files
jrevelstkfararslanvchuravy
authored
add Threads.foreach for convenient multithreaded Channel consumption (#34543)
Co-authored-by: Takafumi Arakaki <[email protected]> Co-authored-by: Alex Arslan <[email protected]> Co-authored-by: Valentin Churavy <[email protected]>
1 parent 150311f commit 39fc4ee

File tree

5 files changed

+98
-1
lines changed

5 files changed

+98
-1
lines changed

Diff for: NEWS.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ Language changes
1414
Compiler/Runtime improvements
1515
-----------------------------
1616

17-
1817
* All platforms can now use `@executable_path` within `jl_load_dynamic_library()`.
1918
This allows executable-relative paths to be embedded within executables on all
2019
platforms, not just MacOS, which the syntax is borrowed from. ([#35627])
@@ -33,14 +32,17 @@ Build system changes
3332

3433
New library functions
3534
---------------------
35+
3636
* New function `Base.kron!` and corresponding overloads for various matrix types for performing Kronecker product in-place. ([#31069]).
37+
* New function `Base.Threads.foreach(f, channel::Channel)` for multithreaded `Channel` consumption. ([#34543]).
3738

3839
New library features
3940
--------------------
4041

4142

4243
Standard library changes
4344
------------------------
45+
4446
* The `nextprod` function now accepts tuples and other array types for its first argument ([#35791]).
4547
* The function `isapprox(x,y)` now accepts the `norm` keyword argument also for numeric (i.e., non-array) arguments `x` and `y` ([#35883]).
4648
* `view`, `@view`, and `@views` now work on `AbstractString`s, returning a `SubString` when appropriate ([#35879]).

Diff for: base/Base.jl

+1
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ include("threads.jl")
223223
include("lock.jl")
224224
include("channels.jl")
225225
include("task.jl")
226+
include("threads_overloads.jl")
226227
include("weakkeydict.jl")
227228

228229
# Logging

Diff for: base/threadingconstructs.jl

+8
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,11 @@ macro spawn(expr)
180180
end
181181
end
182182
end
183+
184+
# This is a stub that can be overloaded for downstream structures like `Channel`
185+
function foreach end
186+
187+
# Scheduling traits that can be employed for downstream overloads
188+
abstract type AbstractSchedule end
189+
struct StaticSchedule <: AbstractSchedule end
190+
struct FairSchedule <: AbstractSchedule end

Diff for: base/threads_overloads.jl

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
"""
4+
Threads.foreach(f, channel::Channel;
5+
schedule::Threads.AbstractSchedule=Threads.FairSchedule(),
6+
ntasks=Threads.nthreads())
7+
8+
Similar to `foreach(f, channel)`, but iteration over `channel` and calls to
9+
`f` are split across `ntasks` tasks spawned by `Threads.@spawn`. This function
10+
will wait for all internally spawned tasks to complete before returning.
11+
12+
If `schedule isa FairSchedule`, `Threads.foreach` will attempt to spawn tasks in a
13+
manner that enables Julia's scheduler to more freely load-balance work items across
14+
threads. This approach generally has higher per-item overhead, but may perform
15+
better than `StaticSchedule` in concurrence with other multithreaded workloads.
16+
17+
If `schedule isa StaticSchedule`, `Threads.foreach` will spawn tasks in a manner
18+
that incurs lower per-item overhead than `FairSchedule`, but is less amenable
19+
to load-balancing. This approach thus may be more suitable for fine-grained,
20+
uniform workloads, but may perform worse than `FairSchedule` in concurrence
21+
with other multithreaded workloads.
22+
23+
!!! compat "Julia 1.6"
24+
This function requires Julia 1.6 or later.
25+
"""
26+
function Threads.foreach(f, channel::Channel;
27+
schedule::Threads.AbstractSchedule=Threads.FairSchedule(),
28+
ntasks=Threads.nthreads())
29+
apply = _apply_for_schedule(schedule)
30+
stop = Threads.Atomic{Bool}(false)
31+
@sync for _ in 1:ntasks
32+
Threads.@spawn try
33+
for item in channel
34+
$apply(f, item)
35+
# do `stop[] && break` after `f(item)` to avoid losing `item`.
36+
# this isn't super comprehensive since a task could still get
37+
# stuck on `take!` at `for item in channel`. We should think
38+
# about a more robust mechanism to avoid dropping items. See also:
39+
# https://github.com/JuliaLang/julia/pull/34543#discussion_r422695217
40+
stop[] && break
41+
end
42+
catch
43+
stop[] = true
44+
rethrow()
45+
end
46+
end
47+
return nothing
48+
end
49+
50+
_apply_for_schedule(::Threads.StaticSchedule) = (f, x) -> f(x)
51+
_apply_for_schedule(::Threads.FairSchedule) = (f, x) -> wait(Threads.@spawn f(x))

Diff for: test/threads_exec.jl

+35
Original file line numberDiff line numberDiff line change
@@ -845,3 +845,38 @@ fib34666(x) =
845845
f(x)
846846
end
847847
@test fib34666(25) == 75025
848+
849+
function jitter_channel(f, k, delay, ntasks, schedule)
850+
x = Channel(ch -> foreach(i -> put!(ch, i), 1:k), 1)
851+
y = Channel(k) do ch
852+
g = i -> begin
853+
iseven(i) && sleep(delay)
854+
put!(ch, f(i))
855+
end
856+
Threads.foreach(g, x; schedule=schedule, ntasks=ntasks)
857+
end
858+
return y
859+
end
860+
861+
@testset "Threads.foreach(f, ::Channel)" begin
862+
k = 50
863+
delay = 0.01
864+
expected = sin.(1:k)
865+
ordered_fair = collect(jitter_channel(sin, k, delay, 1, Threads.FairSchedule()))
866+
ordered_static = collect(jitter_channel(sin, k, delay, 1, Threads.StaticSchedule()))
867+
@test expected == ordered_fair
868+
@test expected == ordered_static
869+
870+
unordered_fair = collect(jitter_channel(sin, k, delay, 10, Threads.FairSchedule()))
871+
unordered_static = collect(jitter_channel(sin, k, delay, 10, Threads.StaticSchedule()))
872+
@test expected != unordered_fair
873+
@test expected != unordered_static
874+
@test Set(expected) == Set(unordered_fair)
875+
@test Set(expected) == Set(unordered_static)
876+
877+
ys = Channel() do ys
878+
inner = Channel(xs -> foreach(i -> put!(xs, i), 1:3))
879+
Threads.foreach(x -> put!(ys, x), inner)
880+
end
881+
@test sort!(collect(ys)) == 1:3
882+
end

0 commit comments

Comments
 (0)