Skip to content

Commit 8c5d550

Browse files
authored
Reconstruct Broadcasted in kernel to help Enzyme.jl (#539)
1 parent 800c237 commit 8c5d550

File tree

1 file changed

+28
-6
lines changed

1 file changed

+28
-6
lines changed

Diff for: src/host/broadcast.jl

+28-6
Original file line numberDiff line numberDiff line change
@@ -47,39 +47,61 @@ end
4747
@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
4848
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
4949
isempty(dest) && return dest
50+
51+
# to help Enzyme.jl, we won't pass the broadcasted object directly
52+
# but instead pass its arguments and reconstruct the object device-side
5053
bc = Broadcast.preprocess(dest, bc)
54+
bcstyle = @static if VERSION >= v"1.10-"
55+
bc.style
56+
else
57+
typeof(BroadcastStyle(typeof(bc)))
58+
end
5159

5260
broadcast_kernel = if ndims(dest) == 1 ||
5361
(isa(IndexStyle(dest), IndexLinear) &&
5462
isa(IndexStyle(bc), IndexLinear))
55-
function (ctx, dest, bc, nelem)
63+
function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...)
    # Linear-index broadcast kernel.
    #
    # The `Broadcasted` object is deliberately not a kernel argument (this is
    # done to help Enzyme.jl); its pieces are passed separately and the object
    # is rebuilt here:
    #   - `bcstyle`: `bc.style` on Julia 1.10+, otherwise the style *type*
    #                `typeof(BroadcastStyle(typeof(bc)))` (see the caller)
    #   - `bcf`:     the broadcasted function (`bc.f`)
    #   - `bcaxes`:  the broadcast axes (`bc.axes`)
    #   - `bcargs`:  the splatted broadcast arguments (`bc.args...`)
    # `nelem` is the number of elements handled per invocation; the caller
    # passes `config.elements_per_thread`.
    bc′ = @static if VERSION >= v"1.10-"
        Broadcasted(bcstyle, bcf, bcargs, bcaxes)
    else
        Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
    end

    i = 1
    while i <= nelem
        I = @linearidx(dest, i)
        # Index the reconstructed `bc′`. Using `bc` here would silently capture
        # the enclosing function's (reassigned) variable in the closure, leaving
        # `bc′` unused and defeating the purpose of the reconstruction.
        @inbounds dest[I] = bc′[I]
        i += 1
    end
    return
end
6478
else
65-
function (ctx, dest, bc, nelem)
79+
function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...)
    # Cartesian-index broadcast kernel, used when the linear-index branch's
    # condition does not hold.
    #
    # The `Broadcasted` object is deliberately not a kernel argument (this is
    # done to help Enzyme.jl); its pieces are passed separately and the object
    # is rebuilt here:
    #   - `bcstyle`: `bc.style` on Julia 1.10+, otherwise the style *type*
    #                `typeof(BroadcastStyle(typeof(bc)))` (see the caller)
    #   - `bcf`:     the broadcasted function (`bc.f`)
    #   - `bcaxes`:  the broadcast axes (`bc.axes`)
    #   - `bcargs`:  the splatted broadcast arguments (`bc.args...`)
    # `nelem` is the number of elements handled per invocation; the caller
    # passes `config.elements_per_thread`.
    bc′ = @static if VERSION >= v"1.10-"
        Broadcasted(bcstyle, bcf, bcargs, bcaxes)
    else
        Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
    end

    i = 0
    while i < nelem
        i += 1
        I = @cartesianidx(dest, i)
        # Index the reconstructed `bc′`. Using `bc` here would silently capture
        # the enclosing function's (reassigned) variable in the closure, leaving
        # `bc′` unused and defeating the purpose of the reconstruction.
        @inbounds dest[I] = bc′[I]
    end
    return
end
7494
end
7595

7696
elements = length(dest)
7797
elements_per_thread = typemax(Int)
78-
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1;
98+
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1,
99+
bcstyle, bc.f, bc.axes, bc.args...;
79100
elements, elements_per_thread)
80101
config = launch_configuration(backend(dest), heuristic;
81102
elements, elements_per_thread)
82-
gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread;
103+
gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int,
104+
bcstyle, bc.f, bc.axes, bc.args...;
83105
threads=config.threads, blocks=config.blocks)
84106

85107
if eltype(dest) <: BrokenBroadcast

0 commit comments

Comments
 (0)