
Projections do not play well with GPUCompiler #429

Open
@pabloferz

Description


Here's an example that does not play well with GPUCompiler: https://gist.github.com/pabloferz/1390d85383e3243015be7ad5b162bcc4

A possible, though probably incomplete, fix discussed with @mcabbott is to add the following specializations:

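import ChainRulesCore: ProjectTo
using ChainRulesCore: ChainRulesCore, NoTangent

# Arrays of floats: project each element to the array's element type, keep the axes.
function ProjectTo(x::AbstractArray{T}) where {T <: AbstractFloat}
    return ProjectTo{AbstractArray}(; element=ProjectTo(zero(T)), axes=axes(x))
end

# Arrays of Bool carry no tangent.
ProjectTo(x::AbstractArray{T}) where {T <: Bool} = ProjectTo{NoTangent}()

# Skip the per-element `map` when dx already has a conforming element type.
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S}) where {S <: Number}
    T = ChainRulesCore.project_type(project.element)
    return S <: T ? dx : map(project.element, dx)
end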
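The last method's fast path can be seen in isolation with a minimal, self-contained sketch that does not depend on ChainRulesCore. `ElementProject`, `target_type`, and `project_array` are hypothetical stand-ins, not the ChainRulesCore API, and the element projector here simply `convert`s to its target type, which is a simplification of what `ProjectTo` does per element.

```julia
# Hypothetical stand-in for a per-element projector; NOT the ChainRulesCore API.
struct ElementProject{T<:Number} end

# Project one element by converting it to the target type T (a simplification).
(::ElementProject{T})(x::Number) where {T} = convert(T, x)

# Hypothetical analogue of `ChainRulesCore.project_type`.
target_type(::ElementProject{T}) where {T} = T

# Fast path: if dx's element type S already conforms to T, return dx itself;
# otherwise project element by element with `map`.
function project_array(p::ElementProject, dx::AbstractArray{S}) where {S<:Number}
    T = target_type(p)
    return S <: T ? dx : map(p, dx)
end

p = ElementProject{Float32}()
a = Float32[1, 2, 3]
same = project_array(p, a)               # Float32 input: returned as-is, no copy
b = project_array(p, [1.0, 2.0, 3.0])    # Float64 input: converted element by element
```

Returning `dx` unchanged when its element type already conforms avoids allocating a new array and calling the element projector on every entry in the common case.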

Metadata

Labels

ProjectTo (related to the projection functionality)