Open
Description
Here's an example that does not work with GPUCompiler: https://gist.github.com/pabloferz/1390d85383e3243015be7ad5b162bcc4
A possible (but probably incomplete) fix, discussed with @mcabbott, is to add the following specializations:
# Projector for arrays whose eltype is a float type: a single element projector,
# built from a zero of the eltype `T`, is shared by every entry, and the primal's
# axes are recorded so cotangents can later be checked/reshaped against them.
# The keyword order (`element` before `axes`) is kept as-is because it fixes the
# NamedTuple type stored inside the projector.
function ProjectTo(x::AbstractArray{T}) where {T<:AbstractFloat}
    shared_element = ProjectTo(zero(T))
    return ProjectTo{AbstractArray}(; element=shared_element, axes=axes(x))
end
# Arrays of `Bool` get a `ProjectTo{NoTangent}` projector, i.e. they are treated
# as non-differentiable; the primal array itself is not inspected.
function ProjectTo(x::AbstractArray{<:Bool})
    return ProjectTo{NoTangent}()
end
# Project an array-of-numbers cotangent `dx` onto the primal array described by
# `project`, a `ProjectTo{AbstractArray}` that records the primal's `axes`.
#
# Shape: if `axes(dx)` differ from the stored primal axes, `dx` is reshaped to
# them provided every dimension agrees, treating missing dimensions as length 1
# (e.g. a 4×1 cotangent for a length-4 vector). Otherwise a `DimensionMismatch`
# is thrown. This method is more specific than the generic array projection, so
# without this step wrongly-shaped cotangents would pass through unchecked.
#
# Elements: when the projector stores one shared `element` projector and the
# eltype `S` of `dx` is already a subtype of its `project_type`, `dx` is returned
# without an element-wise `map`. This is the no-op path the specialization exists
# for (see the GPUCompiler example above). Projectors that instead store one
# projector per entry have no `element` property, so each entry is projected by
# its own projector rather than erroring on `project.element`.
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S<:Number,M}
    primal_axes = project.axes
    dy = if axes(dx) == primal_axes
        dx
    else
        # Only trivial (length-1) dimensions may be added or dropped.
        for d in 1:max(M, length(primal_axes))
            if size(dx, d) != length(get(primal_axes, d, 1))
                msg = "cotangent size $(size(dx)) does not match primal axes $primal_axes"
                throw(DimensionMismatch(msg))
            end
        end
        reshape(dx, primal_axes)
    end
    if hasproperty(project, :element)
        T = ChainRulesCore.project_type(project.element)
        return S <: T ? dy : map(project.element, dy)
    else
        # NOTE(review): assumes per-entry projectors are stored under `elements`, as
        # built by ChainRulesCore's generic `ProjectTo(::AbstractArray)`; confirm.
        return map((f, y) -> f(y), project.elements, dy)
    end
end