|
47 | 47 | @inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
|
48 | 48 | axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
|
49 | 49 | isempty(dest) && return dest
|
| 50 | + |
| 51 | + # to help Enzyme.jl, we won't pass the broadcasted object directly |
| 52 | + # but instead pass its arguments and reconstruct the object device-side |
50 | 53 | bc = Broadcast.preprocess(dest, bc)
|
| 54 | + bcstyle = @static if VERSION >= v"1.10-" |
| 55 | + bc.style |
| 56 | + else |
| 57 | + typeof(BroadcastStyle(typeof(bc))) |
| 58 | + end |
51 | 59 |
|
52 | 60 | broadcast_kernel = if ndims(dest) == 1 ||
|
53 | 61 | (isa(IndexStyle(dest), IndexLinear) &&
|
54 | 62 | isa(IndexStyle(bc), IndexLinear))
|
55 |
| - function (ctx, dest, bc, nelem) |
| 63 | + function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) |
| 64 | + bc′ = @static if VERSION >= v"1.10-" |
| 65 | + Broadcasted(bcstyle, bcf, bcargs, bcaxes) |
| 66 | + else |
| 67 | + Broadcasted{bcstyle}(bcf, bcargs, bcaxes) |
| 68 | + end |
| 69 | + |
56 | 70 | i = 1
|
57 | 71 | while i <= nelem
|
58 | 72 | I = @linearidx(dest, i)
|
59 |
| - @inbounds dest[I] = bc[I] |
| 73 | + @inbounds dest[I] = bc′[I] |
60 | 74 | i += 1
|
61 | 75 | end
|
62 | 76 | return
|
63 | 77 | end
|
64 | 78 | else
|
65 |
| - function (ctx, dest, bc, nelem) |
| 79 | + function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) |
| 80 | + bc′ = @static if VERSION >= v"1.10-" |
| 81 | + Broadcasted(bcstyle, bcf, bcargs, bcaxes) |
| 82 | + else |
| 83 | + Broadcasted{bcstyle}(bcf, bcargs, bcaxes) |
| 84 | + end |
| 85 | + |
66 | 86 | i = 0
|
67 | 87 | while i < nelem
|
68 | 88 | i += 1
|
69 | 89 | I = @cartesianidx(dest, i)
|
70 |
| - @inbounds dest[I] = bc[I] |
| 90 | + @inbounds dest[I] = bc′[I] |
71 | 91 | end
|
72 | 92 | return
|
73 | 93 | end
|
74 | 94 | end
|
75 | 95 |
|
76 | 96 | elements = length(dest)
|
77 | 97 | elements_per_thread = typemax(Int)
|
78 |
| - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1; |
| 98 | + heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1, |
| 99 | + bcstyle, bc.f, bc.axes, bc.args...; |
79 | 100 | elements, elements_per_thread)
|
80 | 101 | config = launch_configuration(backend(dest), heuristic;
|
81 | 102 | elements, elements_per_thread)
|
82 |
| - gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread; |
| 103 | + gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int, |
| 104 | + bcstyle, bc.f, bc.axes, bc.args...; |
83 | 105 | threads=config.threads, blocks=config.blocks)
|
84 | 106 |
|
85 | 107 | if eltype(dest) <: BrokenBroadcast
|
|
0 commit comments