Description
As the title says. Minimal code to reproduce:
using CUDA, Zygote
function test_func(a, b)
    return sum(abs2, a .+ b')
end
a = CUDA.rand(ComplexF64, 3)
b = CUDA.rand(3)
gradient(test_func, a, b)
Produces:
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:35
[2] assertscalar(op::String)
@ GPUArraysCore C:\Users\domin\.julia\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:103
[3] getindex(::CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
@ GPUArrays C:\Users\domin\.julia\packages\GPUArrays\5XhED\src\host\indexing.jl:9
[4] getindex
@ C:\Users\domin\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\LinearAlgebra\src\adjtrans.jl:303 [inlined]
[5] _unsafe_getindex_rs
@ .\reshapedarray.jl:251 [inlined]
[6] _unsafe_getindex
@ .\reshapedarray.jl:248 [inlined]
[7] getindex
@ .\reshapedarray.jl:236 [inlined]
[8] iterate
@ .\abstractarray.jl:1220 [inlined]
[9] iterate
@ .\abstractarray.jl:1218 [inlined]
[10] iterate
@ .\generator.jl:44 [inlined]
[11] _collect(c::Base.ReshapedArray{ComplexF64, 1, LinearAlgebra.Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, itr::Base.Generator{Base.ReshapedArray{ComplexF64, 1, LinearAlgebra.Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
@ Base .\array.jl:802
[12] collect_similar
@ .\array.jl:711 [inlined]
[13] map
@ .\abstractarray.jl:3261 [inlined]
[14] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::LinearAlgebra.Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}})
@ ChainRulesCore C:\Users\domin\.julia\packages\ChainRulesCore\0t04l\src\projection.jl:236
[15] ProjectTo
@ C:\Users\domin\.julia\packages\ChainRulesCore\0t04l\src\projection.jl:414 [inlined]
[16] _project
@ C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\chainrules.jl:189 [inlined]
[17] unbroadcast(x::LinearAlgebra.Adjoint{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}, x̄::CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\lib\broadcast.jl:62
[18] #1172
@ C:\Users\domin\.julia\packages\Zygote\JeHtr\src\lib\broadcast.jl:83 [inlined]
[19] map
@ .\tuple.jl:274 [inlined]
[20] #1171
@ C:\Users\domin\.julia\packages\Zygote\JeHtr\src\lib\broadcast.jl:83 [inlined]
[21] #3754#back
@ C:\Users\domin\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:71 [inlined]
[22] Pullback
@ .\REPL[1]:2 [inlined]
[23] (::Zygote.Pullback{Tuple{typeof(test_func), CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.ZBack{typeof(ChainRules._adjoint_vec_pullback)}, ComposedFunction{Zygote.Pullback{Tuple{Zygote.var"#1441#1442", typeof(abs2), CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#4197#back#1437"{Zygote.var"#1433#1436"{CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3978#back#1283"{Zygote.var"#1279#1282"{CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}}}}, typeof(ZygoteRules.unthunk_tangent)}, Zygote.var"#3754#back#1177"{Zygote.var"#1171#1175"{Tuple{CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}}})(Δ::Float64)
@ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface2.jl:0
[24] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(test_func), CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.ZBack{typeof(ChainRules._adjoint_vec_pullback)}, ComposedFunction{Zygote.Pullback{Tuple{Zygote.var"#1441#1442", typeof(abs2), CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#4197#back#1437"{Zygote.var"#1433#1436"{CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3978#back#1283"{Zygote.var"#1279#1282"{CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}}}}, typeof(ZygoteRules.unthunk_tangent)}, Zygote.var"#3754#back#1177"{Zygote.var"#1171#1175"{Tuple{CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}}}})(Δ::Float64)
@ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:45
[25] gradient(::Function, ::CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
@ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:97
[26] top-level scope
@ REPL[11]:1
[27] top-level scope
@ C:\Users\domin\.julia\packages\CUDA\tVtYo\src\initialization.jl:185
Interestingly, making a real and b complex allows the gradient to run, but it then errors on display, because the output type for the b gradient becomes Base.ReshapedArray{ComplexF64, 1, LinearAlgebra.Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, which the REPL refuses to print. Collecting that array produces a CuArray with the correct gradient.
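For concreteness, the swapped case and the collect workaround look like this (a minimal sketch reusing test_func from above):
using CUDA, Zygote
a = CUDA.rand(3)                    # real primal
b = CUDA.rand(ComplexF64, 3)        # complex primal
da, db = gradient(test_func, a, b)  # runs; db is the lazy ReshapedArray described above
# displaying db errors as described; collecting it works:
collect(db)                         # materializes a CuArray with the correct gradient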
The issue (at least with a complex and b real) seems to stem from ChainRulesCore.jl/src/projection.jl, line 413 in 872d645, which produces an Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}. When that wrapper is then reshaped at projection.jl line 230 (872d645), it becomes a Base.ReshapedArray{ComplexF64, 1, Adjoint{ComplexF64, CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}. Dispatch then sees this as a generic AbstractArray and sends it down Base (CPU) paths when map is called at projection.jl line 236 (872d645), which is where the scalar indexing above is triggered.
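The same wrapper chain can be built by hand outside of Zygote to see where dispatch goes wrong (a minimal sketch; dx and the other names are illustrative stand-ins, not ChainRulesCore internals):
using CUDA, LinearAlgebra
dx = CUDA.rand(ComplexF64, 1, 3)  # stand-in for the row-shaped cotangent of b'
adj = adjoint(dx)                 # Adjoint{ComplexF64, CuArray{...}}, as at projection.jl line 413
flat = reshape(adj, 3)            # Base.ReshapedArray wrapping the Adjoint, as at line 230
map(real, flat)                   # generic Base map iterates elementwise -> the scalar indexing error above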
When a is real and b is complex, the element type S of the gradient matches the element type T of the primal, so the map at ChainRulesCore.jl/src/projection.jl line 236 (872d645) is skipped and the lazy ReshapedArray is returned as-is, which is how it leaks out as the gradient of b.
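Schematically, that element-handling step behaves like the following one-liner (maybe_project is a hypothetical paraphrase for illustration, not the exact ChainRulesCore source):
# dy: reshaped gradient, S: its eltype, T: the target eltype, element: the per-element projector
maybe_project(dy, ::Type{S}, ::Type{T}, element) where {S,T} = S <: T ? dy : map(element, dy)
# S <: T (a real, b complex): dy, the lazy ReshapedArray, is returned untouched and leaks out
# otherwise (a complex, b real): map runs over the wrapped GPU array and hits scalar indexing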
As far as I understand it, this would ideally be fixed by better wrapper-array handling in Base / CuArray, but that seems like a hard, long-lived issue. In the meantime I'm not sure what the best fix would be, or whether the responsibility lies with CUDA or ChainRulesCore. Given that the wrapped array leaks out as the gradient of b in the a-real, b-complex case, perhaps the wrapped-array handling here could be tweaked: when dx is an Adjoint, the reshape could be replaced by an adjoint followed by a broadcast of conj, or the earlier adjoint call in the ProjectTo{Adjoint} method could be a conj broadcast instead? Not sure what would be correct.
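To make the conj suggestion concrete (a hedged sketch of the idea only, not a proposed patch; dx again stands in for the row-shaped cotangent):
using CUDA, LinearAlgebra
dx = CUDA.rand(ComplexF64, 1, 3)
lazy = reshape(adjoint(dx), 3)   # current path: ReshapedArray around an Adjoint around a CuArray
eager = conj.(vec(dx))           # suggested: eager conj broadcast, stays a plain CuArray
CUDA.@allowscalar collect(lazy) == collect(eager)  # true: same values, GPU-friendly type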