From d87ca550faa867fa09f2dd535f3c2288252c02d4 Mon Sep 17 00:00:00 2001
From: James Schloss <jrs.schloss@gmail.com>
Date: Mon, 22 Jul 2024 11:15:45 +0200
Subject: [PATCH 1/4] Transition GPUArrays to KernelAbstractions

---
 .buildkite/pipeline.yml                |  11 +-
 Project.toml                           |   1 +
 docs/src/index.md                      |   5 +-
 docs/src/interface.md                  |  57 +++----
 lib/GPUArraysCore/src/GPUArraysCore.jl |   6 +-
 lib/JLArrays/Project.toml              |   3 +-
 lib/JLArrays/src/JLArrays.jl           | 200 ++++++++++--------------
 src/GPUArrays.jl                       |   8 +-
 src/device/execution.jl                |  83 +---------
 src/device/indexing.jl                 |  85 ----------
 src/device/memory.jl                   |  27 ----
 src/device/synchronization.jl          |  13 --
 src/host/abstractarray.jl              |  32 ++--
 src/host/base.jl                       |  28 ++--
 src/host/broadcast.jl                  |  74 ++++-----
 src/host/construction.jl               |  28 ++--
 src/host/indexing.jl                   |  26 ++--
 src/host/linalg.jl                     | 207 ++++++++++++-------------
 src/host/math.jl                       |   6 +-
 src/host/random.jl                     |  38 ++---
 src/host/uniformscaling.jl             |  28 ++--
 test/Project.toml                      |   1 +
 test/runtests.jl                       |  17 ++
 test/testsuite.jl                      |   2 +-
 test/testsuite/base.jl                 |  27 ++--
 test/testsuite/broadcasting.jl         |   3 +-
 test/testsuite/gpuinterface.jl         |  47 ------
 27 files changed, 374 insertions(+), 689 deletions(-)
 delete mode 100644 src/device/indexing.jl
 delete mode 100644 src/device/memory.jl
 delete mode 100644 src/device/synchronization.jl
 delete mode 100644 test/testsuite/gpuinterface.jl

diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index b7f4aa82..411ca23e 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -10,7 +10,7 @@ steps:
 
                 println("--- :julia: Instantiating project")
                 Pkg.develop(; path=pwd())
-                Pkg.develop(; name="CUDA")
+                Pkg.add(; url="https://github.com/leios/CUDA.jl/", rev="GtK_trans")
 
                 println("+++ :julia: Running tests")
                 Pkg.test("CUDA"; coverage=true)'
@@ -31,10 +31,13 @@ steps:
 
                 println("--- :julia: Instantiating project")
                 Pkg.develop(; path=pwd())
-                Pkg.develop(; name="oneAPI")
+                Pkg.add(; url="https://github.com/leios/oneAPI.jl/", rev="GtK_transition")
 
                 println("+++ :julia: Building support library")
-                include(joinpath(Pkg.devdir(), "oneAPI", "deps", "build_ci.jl"))
+                entryfile = Base.find_package("oneAPI")
+                pkgroot = dirname(dirname(entryfile))
+                filename = joinpath(pkgroot, "deps", "build_ci.jl")
+                include(filename)
                 Pkg.activate()
 
                 println("+++ :julia: Running tests")
@@ -56,7 +59,7 @@ steps:
 
                 println("--- :julia: Instantiating project")
                 Pkg.develop(; path=pwd())
-                Pkg.develop(; name="Metal")
+                Pkg.add(; url="https://github.com/leios/Metal.jl/", rev="GtK_transition")
 
                 println("+++ :julia: Running tests")
                 Pkg.test("Metal"; coverage=true)'
diff --git a/Project.toml b/Project.toml
index f2a4e367..c313a2a6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "10.3.1"
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
diff --git a/docs/src/index.md b/docs/src/index.md
index cfdd3272..8cb100ca 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -9,10 +9,9 @@ will get a lot of functionality for free. This will allow to have multiple GPUAr
 implementation for different purposes, while maximizing the ability to share code.
 
 **This package is not intended for end users!** Instead, you should use one of the packages
-that builds on GPUArrays.jl. There is currently only a single package that actively builds
-on these interfaces, namely [CuArrays.jl](https://github.com/JuliaGPU/CuArrays.jl).
+that builds on GPUArrays.jl, such as [CUDA](https://github.com/JuliaGPU/CUDA.jl), [AMDGPU](https://github.com/JuliaGPU/AMDGPU.jl), [oneAPI](https://github.com/JuliaGPU/oneAPI.jl), or [Metal](https://github.com/JuliaGPU/Metal.jl).
 
-In this documentation, you will find more information on the interface that you are expected
+This documentation is meant for users who might wish to implement a version of GPUArrays for another GPU backend and will cover the features you will need
 to implement, the functionality you gain by doing so, and the test suite that is available
 to verify your implementation. GPUArrays.jl also provides a reference implementation of
 these interfaces on the CPU: The `JLArray` array type uses Julia's parallel programming
diff --git a/docs/src/interface.md b/docs/src/interface.md
index 01c0a3c9..239bef87 100644
--- a/docs/src/interface.md
+++ b/docs/src/interface.md
@@ -1,53 +1,32 @@
 # Interface
 
 To extend the above functionality to a new array type, you should use the types and
-implement the interfaces listed on this page. GPUArrays is design around having two
-different array types to represent a GPU array: one that only ever lives on the host, and
+implement the interfaces listed on this page. GPUArrays is designed around having two
+different array types to represent a GPU array: one that exists only on the host, and
 one that actually can be instantiated on the device (i.e. in kernels).
+Device functionality is then handled by [KernelAbstractions.jl](https://github.com/JuliaGPU/KernelAbstractions.jl).
 
+## Host abstractions
 
-## Device functionality
-
-Several types and interfaces are related to the device and execution of code on it. First of
-all, you need to provide a type that represents your execution back-end and a way to call
-kernels:
+You should provide an array type that builds on the `AbstractGPUArray` supertype, such as:
 
-```@docs
-GPUArrays.AbstractGPUBackend
-GPUArrays.AbstractKernelContext
-GPUArrays.gpu_call
-GPUArrays.thread_block_heuristic
 ```
+mutable struct CustomArray{T, N} <: AbstractGPUArray{T, N}
+    data::DataRef{Vector{UInt8}}
+    offset::Int
+    dims::Dims{N}
+    ...
+end
 
-You then need to provide implementations of certain methods that will be executed on the
-device itself:
-
-```@docs
-GPUArrays.AbstractDeviceArray
-GPUArrays.LocalMemory
-GPUArrays.synchronize_threads
-GPUArrays.blockidx
-GPUArrays.blockdim
-GPUArrays.threadidx
-GPUArrays.griddim
 ```
 
+This will allow your defined type (in this case `CustomArray`) to use the GPUArrays interface where available.
+To be able to actually use the functionality that is defined for `AbstractGPUArray`s, you need to define the backend, like so:
 
-## Host abstractions
-
-You should provide an array type that builds on the `AbstractGPUArray` supertype:
-
-```@docs
-AbstractGPUArray
 ```
-
-First of all, you should implement operations that are expected to be defined for any
-`AbstractArray` type. Refer to the Julia manual for more details, or look at the `JLArray`
-reference implementation.
-
-To be able to actually use the functionality that is defined for `AbstractGPUArray`s, you
-should provide implementations of the following interfaces:
-
-```@docs
-GPUArrays.backend
+import KernelAbstractions: KernelAbstractions, Backend
+struct CustomBackend <: KernelAbstractions.GPU end
+KernelAbstractions.get_backend(a::CA) where CA <: CustomArray = CustomBackend()
 ```
+
+There are numerous examples of potential interfaces for GPUArrays, such as with [JLArrays](https://github.com/JuliaGPU/GPUArrays.jl/blob/master/lib/JLArrays/src/JLArrays.jl), [CuArrays](https://github.com/JuliaGPU/CUDA.jl/blob/master/src/gpuarrays.jl), and [ROCArrays](https://github.com/JuliaGPU/AMDGPU.jl/blob/master/src/gpuarrays.jl).
diff --git a/lib/GPUArraysCore/src/GPUArraysCore.jl b/lib/GPUArraysCore/src/GPUArraysCore.jl
index d5e9a090..117c1e7e 100644
--- a/lib/GPUArraysCore/src/GPUArraysCore.jl
+++ b/lib/GPUArraysCore/src/GPUArraysCore.jl
@@ -218,10 +218,10 @@ end
 
 Gets the GPUArrays back-end responsible for managing arrays of type `T`.
 """
-backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE
-backend(x) = backend(typeof(x))
+get_backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE
+get_backend(x) = get_backend(typeof(x))
 
 # WrappedArray from Adapt for Base wrappers.
-backend(::Type{WA}) where WA<:WrappedArray = backend(unwrap_type(WA))
+get_backend(::Type{WA}) where WA<:WrappedArray = get_backend(unwrap_type(WA)) # recurse on the wrapped array type
 
 end # module GPUArraysCore
diff --git a/lib/JLArrays/Project.toml b/lib/JLArrays/Project.toml
index 60c33183..40af0dba 100644
--- a/lib/JLArrays/Project.toml
+++ b/lib/JLArrays/Project.toml
@@ -6,10 +6,11 @@ version = "0.1.5"
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [compat]
 Adapt = "2.0, 3.0, 4.0"
 GPUArrays = "10"
-julia = "1.8"
 Random = "1"
+julia = "1.8"
diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl
index 10d53daa..0dba44e5 100644
--- a/lib/JLArrays/src/JLArrays.jl
+++ b/lib/JLArrays/src/JLArrays.jl
@@ -1,16 +1,17 @@
 # reference implementation on the CPU
-
-# note that most of the code in this file serves to define a functional array type,
-# the actual implementation of GPUArrays-interfaces is much more limited.
+# This acts as a wrapper around KernelAbstractions' parallel CPU
+# functionality. It is useful for testing GPUArrays (and other packages)
+# when no GPU is present.
+# This file follows conventions from AMDGPU.jl
 
 module JLArrays
 
-export JLArray, JLVector, JLMatrix, jl
-
 using GPUArrays
-
 using Adapt
+import KernelAbstractions
+import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
 
+export JLArray, JLVector, JLMatrix, jl, JLBackend
 
 #
 # Device functionality
@@ -18,37 +19,11 @@ using Adapt
 
 const MAXTHREADS = 256
 
-
-## execution
-
-struct JLBackend <: AbstractGPUBackend end
-
-mutable struct JLKernelContext <: AbstractKernelContext
-    blockdim::Int
-    griddim::Int
-    blockidx::Int
-    threadidx::Int
-
-    localmem_counter::Int
-    localmems::Vector{Vector{Array}}
+struct JLBackend <: KernelAbstractions.GPU
+    static::Bool  # passed on as `KernelAbstractions.CPU(; static)` when a kernel is launched
+    JLBackend(;static::Bool=false) = new(static)  # keyword-only constructor; `static` defaults to `false`
+end
 
-function JLKernelContext(threads::Int, blockdim::Int)
-    blockcount = prod(blockdim)
-    lmems = [Vector{Array}() for i in 1:blockcount]
-    JLKernelContext(threads, blockdim, 1, 1, 0, lmems)
-end
-
-function JLKernelContext(ctx::JLKernelContext, threadidx::Int)
-    JLKernelContext(
-        ctx.blockdim,
-        ctx.griddim,
-        ctx.blockidx,
-        threadidx,
-        0,
-        ctx.localmems
-    )
-end
 
 struct Adaptor end
 jlconvert(arg) = adapt(Adaptor(), arg)
@@ -60,28 +35,37 @@ end
 Base.getindex(r::JlRefValue) = r.x
 Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))
 
-function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int;
-                            name::Union{String,Nothing})
-    ctx = JLKernelContext(threads, blocks)
-    device_args = jlconvert.(args)
-    tasks = Array{Task}(undef, threads)
-    for blockidx in 1:blocks
-        ctx.blockidx = blockidx
-        for threadidx in 1:threads
-            thread_ctx = JLKernelContext(ctx, threadidx)
-            tasks[threadidx] = @async f(thread_ctx, device_args...)
-            # TODO: require 1.3 and use Base.Threads.@spawn for actual multithreading
-            #       (this would require a different synchronization mechanism)
-        end
-        for t in tasks
-            fetch(t)
+mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
+    data::DataRef{Vector{UInt8}}
+
+    offset::Int        # offset of the data in the buffer, in number of elements
+
+    dims::Dims{N}
+
+    # allocating constructor
+    function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
+        check_eltype(T)
+        maxsize = prod(dims) * sizeof(T)
+        data = Vector{UInt8}(undef, maxsize)
+        ref = DataRef(data) do data
+            resize!(data, 0)
         end
+        obj = new{T,N}(ref, 0, dims)
+        finalizer(unsafe_free!, obj)
     end
-    return
-end
 
+    # low-level constructor for wrapping existing data
+    function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
+                          offset::Int=0) where {T,N}
+        check_eltype(T)
+        obj = new{T,N}(ref, offset, dims)
+        finalizer(unsafe_free!, obj)
+    end
+end
 
-## executed on-device
+Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
+Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
+Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)
 
 # array type
 
@@ -107,43 +91,6 @@ end
 @inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(typed_data(A), index)
 @inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(typed_data(A), x, index)
 
-
-# indexing
-
-for f in (:blockidx, :blockdim, :threadidx, :griddim)
-    @eval GPUArrays.$f(ctx::JLKernelContext) = ctx.$f
-end
-
-# memory
-
-function GPUArrays.LocalMemory(ctx::JLKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
-    ctx.localmem_counter += 1
-    lmems = ctx.localmems[blockidx(ctx)]
-
-    # first invocation in block
-    data = if length(lmems) < ctx.localmem_counter
-        lmem = fill(zero(T), dims)
-        push!(lmems, lmem)
-        lmem
-    else
-        lmems[ctx.localmem_counter]
-    end
-
-    N = length(dims)
-    JLDeviceArray{T,N}(data, tuple(dims...))
-end
-
-# synchronization
-
-@inline function GPUArrays.synchronize_threads(::JLKernelContext)
-    # All threads are getting started asynchronously, so a yield will yield to the next
-    # execution of the same function, which should call yield at the exact same point in the
-    # program, leading to a chain of yields effectively syncing the tasks (threads).
-    yield()
-    return
-end
-
-
 #
 # Host abstractions
 #
@@ -157,34 +104,6 @@ function check_eltype(T)
   end
 end
 
-mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
-    data::DataRef{Vector{UInt8}}
-
-    offset::Int        # offset of the data in the buffer, in number of elements
-
-    dims::Dims{N}
-
-    # allocating constructor
-    function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
-        check_eltype(T)
-        maxsize = prod(dims) * sizeof(T)
-        data = Vector{UInt8}(undef, maxsize)
-        ref = DataRef(data) do data
-            resize!(data, 0)
-        end
-        obj = new{T,N}(ref, 0, dims)
-        finalizer(unsafe_free!, obj)
-    end
-
-    # low-level constructor for wrapping existing data
-    function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
-                          offset::Int=0) where {T,N}
-        check_eltype(T)
-        obj = new{T,N}(ref, offset, dims)
-        finalizer(unsafe_free!, obj)
-    end
-end
-
 unsafe_free!(a::JLArray) = GPUArrays.unsafe_free!(a.data)
 
 # conversion of untyped data to a typed Array
@@ -409,8 +328,6 @@ end
 
 ## GPUArrays interfaces
 
-GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
-
 Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
   JLDeviceArray{T,N}(x.data[], x.offset, x.dims)
 
@@ -423,4 +340,47 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
     R
 end
 
+## KernelAbstractions interface
+
+KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
+
+function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
+    return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
+end
+
+KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArray{T}(undef, dims)
+
+@inline function launch_config(kernel::Kernel{JLBackend}, ndrange, workgroupsize)
+    if ndrange isa Integer
+        ndrange = (ndrange,)  # normalize a scalar range to a 1-tuple
+    end
+    if workgroupsize isa Integer
+        workgroupsize = (workgroupsize, )  # same normalization for the workgroup size
+    end
+
+    if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing
+        workgroupsize = (1024,) # Vectorization, 4x unrolling, minimal grain size
+    end
+    iterspace, dynamic = partition(kernel, ndrange, workgroupsize)
+    # `partition` has already checked that the kernel's and the caller's ndrange agree
+    if KernelAbstractions.ndrange(kernel) <: StaticSize
+        ndrange = nothing  # statically sized kernels report `nothing` as the runtime ndrange
+    end
+
+    return ndrange, workgroupsize, iterspace, dynamic
+end
+
+KernelAbstractions.isgpu(b::JLBackend) = false
+
+# Rebuild `obj` as a kernel for KernelAbstractions' CPU backend, carrying over `static`.
+function convert_to_cpu(obj::Kernel{JLBackend, W, N, F}) where {W, N, F}
+    cpu = KernelAbstractions.CPU(; static = obj.backend.static)
+    return Kernel{typeof(cpu), W, N, F}(cpu, obj.f)
+end
+
+function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing)
+    device_args = jlconvert.(args)  # adapt every argument with `Adaptor`
+    return convert_to_cpu(obj)(device_args...; ndrange, workgroupsize)
+end
+
 end
diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl
index 2d4f1bd9..54e9c877 100644
--- a/src/GPUArrays.jl
+++ b/src/GPUArrays.jl
@@ -1,5 +1,6 @@
 module GPUArrays
 
+using KernelAbstractions
 using Serialization
 using Random
 using LinearAlgebra
@@ -14,14 +15,11 @@ using LLVM.Interop
 using Reexport
 @reexport using GPUArraysCore
 
-# device functionality
-include("device/execution.jl")
 ## executed on-device
+include("device/execution.jl")
 include("device/abstractarray.jl")
-include("device/indexing.jl")
-include("device/memory.jl")
-include("device/synchronization.jl")
 
+using KernelAbstractions
 # host abstractions
 include("host/abstractarray.jl")
 include("host/construction.jl")
diff --git a/src/device/execution.jl b/src/device/execution.jl
index 41285bc3..64a81dad 100644
--- a/src/device/execution.jl
+++ b/src/device/execution.jl
@@ -1,75 +1,5 @@
 # kernel execution
 
-export AbstractGPUBackend, AbstractKernelContext, gpu_call
-
-abstract type AbstractGPUBackend end
-
-abstract type AbstractKernelContext end
-
-import GPUArraysCore: backend
-
-"""
-    gpu_call(kernel::Function, arg0, args...; kwargs...)
-
-Executes `kernel` on the device that backs `arg` (see [`backend`](@ref)), passing along any
-arguments `args`. Additionally, the kernel will be passed the kernel execution context (see
-[`AbstractKernelContext`]), so its signature should be `(ctx::AbstractKernelContext, arg0,
-args...)`.
-
-The keyword arguments `kwargs` are not passed to the function, but are interpreted on the
-host to influence how the kernel is executed. The following keyword arguments are supported:
-
-- `target::AbstractArray`: specify which array object to use for determining execution
-  properties (defaults to the first argument `arg0`).
-- `elements::Int`: how many elements will be processed by this kernel. In most
-  circumstances, this will correspond to the total number of threads that needs to be
-  launched, unless the kernel supports a variable number of elements to process per
-  iteration. Defaults to the length of `arg0` if no other keyword arguments that influence
-  the launch configuration are specified.
-- `threads::Int` and `blocks::Int`: configure exactly how many threads and blocks are
-  launched. This cannot be used in combination with the `elements` argument.
-- `name::String`: inform the back end about the name of the kernel to be executed. This can
-  be used to emit better diagnostics, and is useful with anonymous kernels.
-"""
-function gpu_call(kernel::F, args::Vararg{Any,N};
-                  target::AbstractArray=first(args),
-                  elements::Union{Int,Nothing}=nothing,
-                  threads::Union{Int,Nothing}=nothing,
-                  blocks::Union{Int,Nothing}=nothing,
-                  name::Union{String,Nothing}=nothing) where {F,N}
-    # non-trivial default values for launch configuration
-    if elements===nothing && threads===nothing && blocks===nothing
-        elements = length(target)
-    elseif elements===nothing
-        if threads === nothing
-            threads = 1
-        end
-        if blocks === nothing
-            blocks = 1
-        end
-    elseif threads!==nothing || blocks!==nothing
-        error("Cannot specify both elements and threads/blocks configuration")
-    end
-
-    # the number of elements to process needs to be passed to the kernel somehow, so there's
-    # no easy way to do this without passing additional arguments or changing the context.
-    # both are expensive, so require manual use of `launch_heuristic` for those kernels.
-    elements_per_thread = 1
-
-    if elements !== nothing
-        @assert elements > 0
-        heuristic = launch_heuristic(backend(target), kernel, args...;
-                                     elements, elements_per_thread)
-        config = launch_configuration(backend(target), heuristic;
-                                      elements, elements_per_thread)
-        gpu_call(backend(target), kernel, args, config.threads, config.blocks; name=name)
-    else
-        @assert threads > 0
-        @assert blocks > 0
-        gpu_call(backend(target), kernel, args, threads, blocks; name=name)
-    end
-end
-
 # how many threads and blocks `kernel` needs to be launched with, passing arguments `args`,
 # to fully saturate the GPU. `elements` indicates the number of elements that needs to be
 # processed, while `elements_per_threads` indicates the number of elements this kernel can
@@ -77,16 +7,18 @@ end
 #
 # this heuristic should be specialized for the back-end, ideally using an API for maximizing
 # the occupancy of the launch configuration (like CUDA's occupancy API).
-function launch_heuristic(backend::AbstractGPUBackend, kernel, args...;
-                          elements::Int, elements_per_thread::Int)
+function launch_heuristic(backend::B, kernel, args...;
+                          elements::Int,
+                          elements_per_thread::Int) where B <: Backend
     return (threads=256, blocks=32)
 end
 
 # determine how many threads and blocks to actually launch given upper limits.
 # returns a tuple of blocks, threads, and elements_per_thread (which is always 1
 # unless specified that the kernel can handle a number of elements per thread)
-function launch_configuration(backend::AbstractGPUBackend, heuristic;
-                              elements::Int, elements_per_thread::Int)
+function launch_configuration(backend::B, heuristic;
+                              elements::Int,
+                              elements_per_thread::Int) where B <: Backend
     threads = clamp(elements, 1, heuristic.threads)
     blocks = max(cld(elements, threads), 1)
 
@@ -105,6 +37,3 @@ function launch_configuration(backend::AbstractGPUBackend, heuristic;
         (; threads, blocks, elements_per_thread=1)
     end
 end
-
-gpu_call(backend::AbstractGPUBackend, kernel, args, threads::Int, blocks::Int; kwargs...) =
-    error("Not implemented") # COV_EXCL_LINE
diff --git a/src/device/indexing.jl b/src/device/indexing.jl
deleted file mode 100644
index 31084fce..00000000
--- a/src/device/indexing.jl
+++ /dev/null
@@ -1,85 +0,0 @@
-# indexing
-
-export global_index, global_size, linear_index, @linearidx, @cartesianidx
-
-
-## hardware
-
-for f in (:blockidx, :blockdim, :threadidx, :griddim)
-    @eval $f(ctx::AbstractKernelContext)::Int = error("Not implemented") # COV_EXCL_LINE
-    @eval export $f
-end
-
-"""
-    global_index(ctx::AbstractKernelContext)
-
-Query the global index of the current thread in the launch configuration (i.e. as far as the
-hardware is concerned).
-"""
-@inline function global_index(ctx::AbstractKernelContext)
-    threadidx(ctx) + (blockidx(ctx) - 1) * blockdim(ctx)
-end
-
-"""
-    global_size(ctx::AbstractKernelContext)
-
-Query the global size of the launch configuration (total number of threads launched).
-"""
-@inline function global_size(ctx::AbstractKernelContext)
-    griddim(ctx) * blockdim(ctx)
-end
-
-
-## logical
-
-"""
-    linear_index(ctx::AbstractKernelContext, grididx::Int=1)
-
-Return a linear index for the current kernel by querying hardware registers (similar to
-`get_global_id` in OpenCL). For applying a grid stride (in terms of [`global_size`](@ref)),
-specify `grididx`.
-
-"""
-@inline function linear_index(ctx::AbstractKernelContext, grididx::Int=1)
-    global_index(ctx) + (grididx - 1) * global_size(ctx)
-end
-
-"""
-    linearidx(A, grididx=1, ctxsym=:ctx)
-
-Macro form of [`linear_index`](@ref), which return from the surrouunding scope when out of
-bounds:
-
-    ```julia
-    function kernel(ctx::AbstractKernelContext, A)
-        idx = @linearidx A
-        # from here on it's safe to index into A with idx
-        @inbounds begin
-            A[idx] = ...
-        end
-    end
-    ```
-"""
-macro linearidx(A, grididx=1, ctxsym=:ctx)
-    quote
-        x = $(esc(A))
-        i = linear_index($(esc(ctxsym)), $(esc(grididx)))
-        if !(1 <= i <= length(x))
-            return
-        end
-        i
-    end
-end
-
-"""
-    cartesianidx(A, grididx=1, ctxsym=:ctx)
-
-Like [`@linearidx`](@ref), but returns a N-dimensional `CartesianIndex`.
-"""
-macro cartesianidx(A, grididx=1, ctxsym=:ctx)
-    quote
-        x = $(esc(A))
-        i = @linearidx(x, $(esc(grididx)), $(esc(ctxsym)))
-        @inbounds CartesianIndices(x)[i]
-    end
-end
diff --git a/src/device/memory.jl b/src/device/memory.jl
deleted file mode 100644
index 901791d5..00000000
--- a/src/device/memory.jl
+++ /dev/null
@@ -1,27 +0,0 @@
-# on-device memory management
-
-export @LocalMemory
-
-
-## thread-local array
-
-"""
-Creates a local static memory shared inside one block.
-Equivalent to `__local` of OpenCL or `__shared__ (<variable>)` of CUDA.
-"""
-macro LocalMemory(ctx, T, N)
-    id = gensym("local_memory")
-    quote
-        LocalMemory($(esc(ctx)), $(esc(T)), Val($(esc(N))), Val($(QuoteNode(id))))
-    end
-end
-
-"""
-Creates a block local array pointer with `T` being the element type
-and `N` the length. Both T and N need to be static! C is a counter for
-approriately get the correct Local mem id in CUDAnative.
-This is an internal method which needs to be overloaded by the GPU Array backends
-"""
-function LocalMemory(ctx, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
-    error("Not implemented") # COV_EXCL_LINE
-end
diff --git a/src/device/synchronization.jl b/src/device/synchronization.jl
deleted file mode 100644
index b16d2518..00000000
--- a/src/device/synchronization.jl
+++ /dev/null
@@ -1,13 +0,0 @@
-# synchronization
-
-export synchronize_threads
-
-"""
-     synchronize_threads(ctx::AbstractKernelContext)
-
-in CUDA terms `__synchronize`
-in OpenCL terms: `barrier(CLK_LOCAL_MEM_FENCE)`
-"""
-function synchronize_threads(ctx::AbstractKernelContext)
-    error("Not implemented") # COV_EXCL_LINE
-end
diff --git a/src/host/abstractarray.jl b/src/host/abstractarray.jl
index ab0a5c4c..e30dbced 100644
--- a/src/host/abstractarray.jl
+++ b/src/host/abstractarray.jl
@@ -159,13 +159,12 @@ for (D, S) in ((AnyGPUArray, Array),
 end
 
 # kernel-based variant for copying between wrapped GPU arrays
-
-function linear_copy_kernel!(ctx::AbstractKernelContext, dest, dstart, src, sstart, n)
-    i = linear_index(ctx)-1
-    if i < n
-        @inbounds dest[dstart+i] = src[sstart+i]
+# TODO: Add `@Const` to `src`
+@kernel function linear_copy_kernel!(dest, dstart, src, sstart, n)
+    i = @index(Global, Linear)
+    if i <= n
+        @inbounds dest[dstart+i-1] = src[sstart+i-1]
     end
-    return
 end
 
 function Base.copyto!(dest::AnyGPUArray, dstart::Integer,
@@ -175,10 +174,8 @@ function Base.copyto!(dest::AnyGPUArray, dstart::Integer,
     destinds, srcinds = LinearIndices(dest), LinearIndices(src)
     (checkbounds(Bool, destinds, dstart) && checkbounds(Bool, destinds, dstart+n-1)) || throw(BoundsError(dest, dstart:dstart+n-1))
     (checkbounds(Bool, srcinds, sstart)  && checkbounds(Bool, srcinds, sstart+n-1))  || throw(BoundsError(src,  sstart:sstart+n-1))
-
-    gpu_call(linear_copy_kernel!,
-             dest, dstart, src, sstart, n;
-             elements=n)
+    kernel = linear_copy_kernel!(get_backend(dest))
+    kernel(dest, dstart, src, sstart, n; ndrange=n)
     return dest
 end
 
@@ -228,13 +225,9 @@ end
 
 ## generalized blocks of heterogeneous memory
 
-function cartesian_copy_kernel!(ctx::AbstractKernelContext, dest, dest_offsets, src, src_offsets, shape, length)
-    i = linear_index(ctx)
-    if i <= length
-        idx = CartesianIndices(shape)[i]
-        @inbounds dest[idx + dest_offsets] = src[idx + src_offsets]
-    end
-    return
+@kernel function cartesian_copy_kernel!(dest, dest_offsets, src, src_offsets)
+    I = @index(Global, Cartesian)
+    @inbounds dest[I + dest_offsets] = src[I + src_offsets]
 end
 
 function Base.copyto!(dest::AnyGPUArray{<:Any, N}, destcrange::CartesianIndices{N},
@@ -255,9 +248,8 @@ function Base.copyto!(dest::AnyGPUArray{<:Any, N}, destcrange::CartesianIndices{
 
     dest_offsets = first(destcrange) - oneunit(CartesianIndex{N})
     src_offsets = first(srccrange) - oneunit(CartesianIndex{N})
-    gpu_call(cartesian_copy_kernel!,
-             dest, dest_offsets, src, src_offsets, shape, len;
-             elements=len)
+    kernel = cartesian_copy_kernel!(get_backend(dest))
+    kernel(dest, dest_offsets, src, src_offsets; ndrange=shape)
     dest
 end
 
diff --git a/src/host/base.jl b/src/host/base.jl
index c5d1b0bb..eb381ebb 100644
--- a/src/host/base.jl
+++ b/src/host/base.jl
@@ -4,8 +4,7 @@ import Base: _RepeatInnerOuter
 # Handle `out = repeat(x; inner)` by parallelizing over `out` array This can benchmark
 # faster if repeating elements along the first axis (i.e. `inner=(n, ones...)`), as data
 # access can be contiguous on write.
-function repeat_inner_dst_kernel!(
-    ctx::AbstractKernelContext,
+@kernel function repeat_inner_dst_kernel!(
     xs::AbstractArray{<:Any, N},
     inner::NTuple{N, Int},
     out::AbstractArray{<:Any, N}
@@ -13,27 +12,25 @@ function repeat_inner_dst_kernel!(
     # Get the "stride" index in each dimension, where the size
     # of the stride is given by `inner`. The stride-index (sdx) then
     # corresponds to the index of the repeated value in `xs`.
-    odx = @cartesianidx out
+    odx = @index(Global, Cartesian)
     dest_inds = odx.I
     sdx = ntuple(N) do i
         @inbounds (dest_inds[i] - 1) ÷ inner[i] + 1
     end
     @inbounds out[odx] = xs[CartesianIndex(sdx)]
-    return nothing
 end
 
 # Handle `out = repeat(x; inner)` by parallelizing over the `xs` array This tends to
 # benchmark faster by having fewer read operations and avoiding the costly division
 # operation. Additionally, when repeating over the trailing dimension. `inner=(ones..., n)`,
 # data access can be contiguous during both the read and write operations.
-function repeat_inner_src_kernel!(
-    ctx::AbstractKernelContext,
+@kernel function repeat_inner_src_kernel!(
     xs::AbstractArray{<:Any, N},
     inner::NTuple{N, Int},
     out::AbstractArray{<:Any, N}
 ) where {N}
     # Get single element from src
-    idx = @cartesianidx xs
+    idx = @index(Global, Cartesian)
     @inbounds val = xs[idx]
 
     # Loop over "repeat" indices of inner
@@ -44,7 +41,6 @@ function repeat_inner_src_kernel!(
         end
         @inbounds out[CartesianIndex(odx)] = val
     end
-    return nothing
 end
 
 function repeat_inner(xs::AnyGPUArray, inner)
@@ -64,23 +60,24 @@ function repeat_inner(xs::AnyGPUArray, inner)
     # relevant benchmarks.
     if argmax(inner) == firstindex(inner)
         # Parallelize over the destination array
-        gpu_call(repeat_inner_dst_kernel!, xs, inner, out; elements=prod(size(out)))
+        kernel = repeat_inner_dst_kernel!(get_backend(out))
+        kernel(xs, inner, out; ndrange=size(out))
     else
         # Parallelize over the source array
-        gpu_call(repeat_inner_src_kernel!, xs, inner, out; elements=prod(size(xs)))
+        kernel = repeat_inner_src_kernel!(get_backend(xs))
+        kernel(xs, inner, out; ndrange=size(xs))
     end
     return out
 end
 
-function repeat_outer_kernel!(
-    ctx::AbstractKernelContext,
+@kernel function repeat_outer_kernel!(
     xs::AbstractArray{<:Any, N},
     xssize::NTuple{N},
     outer::NTuple{N},
     out::AbstractArray{<:Any, N}
 ) where {N}
     # Get index to input element
-    idx = @cartesianidx xs
+    idx = @index(Global, Cartesian)
     @inbounds val = xs[idx]
 
     # Loop over repeat indices, copying val to out
@@ -91,14 +88,13 @@ function repeat_outer_kernel!(
         end
         @inbounds out[CartesianIndex(odx)] = val
     end
-
-    return nothing
 end
 
 function repeat_outer(xs::AnyGPUArray, outer)
     out = similar(xs, eltype(xs), outer .* size(xs))
     any(==(0), size(out)) && return out # consistent with `Base.repeat`
-    gpu_call(repeat_outer_kernel!, xs, size(xs), outer, out; elements=length(xs))
+    kernel = repeat_outer_kernel!(get_backend(xs))
+    kernel(xs, size(xs), outer, out; ndrange=size(xs))
     return out
 end
 
diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl
index a78e6c7c..e8e01084 100644
--- a/src/host/broadcast.jl
+++ b/src/host/broadcast.jl
@@ -49,39 +49,26 @@ end
     isempty(dest) && return dest
     bc = Broadcast.preprocess(dest, bc)
 
-    broadcast_kernel = if ndims(dest) == 1 ||
-                          (isa(IndexStyle(dest), IndexLinear) &&
+    @kernel function broadcast_kernel_linear(dest, bc)
+        I = @index(Global, Linear)
+        @inbounds dest[I] = bc[I]
+    end
+
+    @kernel function broadcast_kernel_cartesian(dest, bc)
+        I = @index(Global, Cartesian)
+        @inbounds dest[I] = bc[I]
+    end
+
+    # one work-item per element; ndrange is forced to 1 for a possible 0D evaluation
+    if ndims(dest) == 1 || (isa(IndexStyle(dest), IndexLinear) &&
                            isa(IndexStyle(bc), IndexLinear))
-        function (ctx, dest, bc, nelem)
-            i = 1
-            while i <= nelem
-                I = @linearidx(dest, i)
-                @inbounds dest[I] = bc[I]
-                i += 1
-            end
-            return
-        end
+        broadcast_kernel_linear(get_backend(dest))(dest, bc;
+            ndrange = length(size(dest)) > 0 ? length(dest) : 1)
     else
-        function (ctx, dest, bc, nelem)
-            i = 0
-            while i < nelem
-                i += 1
-                I = @cartesianidx(dest, i)
-                @inbounds dest[I] = bc[I]
-            end
-            return
-        end
+        broadcast_kernel_cartesian(get_backend(dest))(dest, bc;
+            ndrange = length(size(dest)) > 0 ? size(dest) : (1,))
     end
 
-    elements = length(dest)
-    elements_per_thread = typemax(Int)
-    heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1;
-                                 elements, elements_per_thread)
-    config = launch_configuration(backend(dest), heuristic;
-                                  elements, elements_per_thread)
-    gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread;
-             threads=config.threads, blocks=config.blocks)
-
     if eltype(dest) <: BrokenBroadcast
         throw(ArgumentError("Broadcast operation resulting in $(eltype(eltype(dest))) is not GPU compatible"))
     end
@@ -130,27 +117,28 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
     end
 
     # grid-stride kernel
-    function map_kernel(ctx, dest, bc, nelem)
-        i = 1
-        while i <= nelem
-            j = linear_index(ctx, i)
-            j > common_length && return
+    @kernel function map_kernel(dest, bc, nelem, common_length)
 
-            J = CartesianIndices(axes(bc))[j]
-            @inbounds dest[j] = bc[J]
+        # work-item J owns the linear indices (J-1)*nelem+1 : J*nelem
+        J = @index(Global, Linear)
+        for i in 1:nelem
+            j = (J-1)*nelem + i   # global linear index handled in this iteration
+            if j <= common_length # guard the global index, not the local counter
 
-            i += 1
+                J_c = CartesianIndices(axes(bc))[j]
+                @inbounds dest[j] = bc[J_c]  # dest is linear: its shape may differ from bc
+            end
         end
-        return
     end
     elements = common_length
     elements_per_thread = typemax(Int)
-    heuristic = launch_heuristic(backend(dest), map_kernel, dest, bc, 1;
-                                 elements, elements_per_thread)
-    config = launch_configuration(backend(dest), heuristic;
+    kernel = map_kernel(get_backend(dest))
+    heuristic = launch_heuristic(get_backend(dest), kernel, dest, bc, 1,
+                                 common_length; elements, elements_per_thread)
+    config = launch_configuration(get_backend(dest), heuristic;
                                   elements, elements_per_thread)
-    gpu_call(map_kernel, dest, bc, config.elements_per_thread;
-             threads=config.threads, blocks=config.blocks)
+    kernel(dest, bc, config.elements_per_thread, common_length;
+           ndrange = cld(common_length, config.elements_per_thread))
 
     if eltype(dest) <: BrokenBroadcast
         throw(ArgumentError("Map operation resulting in $(eltype(eltype(dest))) is not GPU compatible"))
diff --git a/src/host/construction.jl b/src/host/construction.jl
index d80bce2d..18a3b6d7 100644
--- a/src/host/construction.jl
+++ b/src/host/construction.jl
@@ -11,29 +11,33 @@ Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractGPUArray} = a isa T
 
 function Base.fill!(A::AnyGPUArray{T}, x) where T
     isempty(A) && return A
-    gpu_call(A, convert(T, x)) do ctx, a, val
-        idx = @linearidx(a)
+    @kernel function fill_kernel!(a, val)
+        idx = @index(Global, Linear)
         @inbounds a[idx] = val
-        return
     end
+
+    # convert on the host (as before the port); ndrange set for a possible 0D evaluation
+    fill_kernel!(get_backend(A))(A, convert(T, x),
+                                 ndrange = length(size(A)) > 0 ? size(A) : (1,))
     A
 end
 
 
 ## identity matrices
 
-function identity_kernel(ctx::AbstractKernelContext, res::AbstractArray{T}, stride, val) where T
-    i = linear_index(ctx)
+@kernel function identity_kernel(res::AbstractArray{T}, stride, val) where T
+    i = @index(Global, Linear)
     ilin = (stride * (i - 1)) + i
-    ilin > length(res) && return
-    @inbounds res[ilin] = val
-    return
+    if ilin <= length(res)
+        @inbounds res[ilin] = val
+    end
 end
 
 function (T::Type{<: AnyGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
     res = similar(T, dims)
     fill!(res, zero(U))
-    gpu_call(identity_kernel, res, size(res, 1), s.λ; elements=minimum(dims))
+    kernel = identity_kernel(get_backend(res))
+    kernel(res, size(res, 1), s.λ; ndrange=minimum(dims))
     res
 end
 
@@ -43,7 +47,8 @@ end
 
 function Base.copyto!(A::AbstractGPUMatrix{T}, s::UniformScaling) where T
     fill!(A, zero(T))
-    gpu_call(identity_kernel, A, size(A, 1), s.λ; elements=minimum(size(A)))
+    kernel = identity_kernel(get_backend(A))
+    kernel(A, size(A, 1), s.λ; ndrange=minimum(size(A)))
     A
 end
 
@@ -52,7 +57,8 @@ function _one(unit::T, x::AbstractGPUMatrix) where {T}
     m==n || throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
     I = similar(x, T)
     fill!(I, zero(T))
-    gpu_call(identity_kernel, I, m, unit; elements=m)
+    kernel = identity_kernel(get_backend(I))
+    kernel(I, m, unit; ndrange=m)
     I
 end
 
diff --git a/src/host/indexing.jl b/src/host/indexing.jl
index 659fb029..4985d69d 100644
--- a/src/host/indexing.jl
+++ b/src/host/indexing.jl
@@ -72,7 +72,7 @@ end
     Is = map(adapt(ToGPU(dest)), Is)
     @boundscheck checkbounds(src, Is...)
 
-    gpu_call(getindex_kernel, dest, src, idims, Is...)
+    getindex_kernel(get_backend(dest))(dest, src, idims, Is...; ndrange=size(dest))
     return dest
 end
 
@@ -82,15 +82,19 @@ end
     return vectorized_getindex!(dest, src, Is...)
 end
 
-@generated function getindex_kernel(ctx::AbstractKernelContext, dest, src, idims,
-                                    Is::Vararg{Any,N}) where {N}
+@kernel function getindex_kernel(dest, src, idims,
+                                 Is::Vararg{Any,N}) where {N}
+    i = @index(Global, Linear)
+    getindex_generated(dest, src, idims, i, Is...)
+end
+
+@generated function getindex_generated(dest, src, idims, i,
+                                       Is::Vararg{Any,N}) where {N}
     quote
-        i = @linearidx dest
         is = @inbounds CartesianIndices(idims)[i]
         @nexprs $N i -> I_i = @inbounds(Is[i][is[i]])
         val = @ncall $N getindex src i -> I_i
         @inbounds dest[i] = val
-        return
     end
 end
 
@@ -111,15 +115,19 @@ end
     Is = map(adapt(ToGPU(dest)), Is)
     @boundscheck checkbounds(dest, Is...)
 
-    gpu_call(setindex_kernel, dest, adapt(ToGPU(dest), src), idims, len, Is...;
-             elements=len)
+    setindex_kernel(get_backend(dest))(dest, adapt(ToGPU(dest), src), idims, len, Is...;
+             ndrange = len)  # one work-item per assigned element, matching elements=len
     return dest
 end
 
-@generated function setindex_kernel(ctx::AbstractKernelContext, dest, src, idims, len,
+@kernel function setindex_kernel(dest, src, idims, len,
                                     Is::Vararg{Any,N}) where {N}
+    i = @index(Global, Linear)
+    setindex_generated(dest, src, idims, len, i, Is...)
+end
+@generated function setindex_generated(dest, src, idims, len, i,
+                                       Is::Vararg{Any,N}) where {N}
     quote
-        i = linear_index(ctx)
         i > len && return
         is = @inbounds CartesianIndices(idims)[i]
         @nexprs $N i -> I_i = @inbounds(Is[i][is[i]])
diff --git a/src/host/linalg.jl b/src/host/linalg.jl
index fb23e42c..ad400ec0 100644
--- a/src/host/linalg.jl
+++ b/src/host/linalg.jl
@@ -14,20 +14,20 @@ function LinearAlgebra.transpose!(B::AbstractGPUMatrix, A::AbstractGPUVector)
 end
 function LinearAlgebra.adjoint!(B::AbstractGPUVector, A::AbstractGPUMatrix)
     axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("adjoint"))
-    gpu_call(B, A) do ctx, B, A
-        idx = @linearidx B
+    @kernel function adjoint_kernel!(B, A)
+        idx = @index(Global, Linear)
         @inbounds B[idx] = adjoint(A[1, idx])
-        return
     end
+    adjoint_kernel!(get_backend(B))(B, A, ndrange = size(B))
     B
 end
 function LinearAlgebra.adjoint!(B::AbstractGPUMatrix, A::AbstractGPUVector)
     axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("adjoint"))
-    gpu_call(B, A) do ctx, B, A
-        idx = @linearidx A
+    @kernel function adjoint_kernel!(B, A)
+        idx = @index(Global, Linear)
         @inbounds B[1, idx] = adjoint(A[idx])
-        return
     end
+    adjoint_kernel!(get_backend(A))(B, A, ndrange = size(A))
     B
 end
 
@@ -35,11 +35,11 @@ LinearAlgebra.transpose!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(transpos
 LinearAlgebra.adjoint!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(adjoint, B, A)
 function transpose_f!(f, B::AnyGPUMatrix{T}, A::AnyGPUMatrix{T}) where T
     axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) || throw(DimensionMismatch(string(f)))
-    gpu_call(B, A) do ctx, B, A
-        idx = @cartesianidx A
+    @kernel function transpose_kernel!(B, A)
+        idx = @index(Global, Cartesian)
         @inbounds B[idx[2], idx[1]] = f(A[idx[1], idx[2]])
-        return
     end
+    transpose_kernel!(get_backend(B))(B, A, ndrange = size(A))
     B
 end
 
@@ -60,48 +60,48 @@ end
 
 ## copy upper triangle to lower and vice versa
 
-function LinearAlgebra.copytri!(A::AbstractGPUMatrix, uplo::AbstractChar, conjugate::Bool=false)
-  n = LinearAlgebra.checksquare(A)
-  if uplo == 'U' && conjugate
-      gpu_call(A) do ctx, _A
-        I = @cartesianidx _A
-        i, j = Tuple(I)
-        if j > i
-          @inbounds _A[j,i] = conj(_A[i,j])
+function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, conjugate::Bool=false) where T
+    n = LinearAlgebra.checksquare(A)
+    if uplo == 'U' && conjugate
+        @kernel function U_conj!(_A)
+            I = @index(Global, Cartesian)
+            i, j = Tuple(I)
+            if j > i
+              @inbounds _A[j,i] = conj(_A[i,j])
+            end
         end
-        return
-      end
-  elseif uplo == 'U' && !conjugate
-      gpu_call(A) do ctx, _A
-        I = @cartesianidx _A
-        i, j = Tuple(I)
-        if j > i
-          @inbounds _A[j,i] = _A[i,j]
+        U_conj!(get_backend(A))(A, ndrange = size(A))
+    elseif uplo == 'U' && !conjugate
+        @kernel function U_noconj!(_A)
+            I = @index(Global, Cartesian)
+            i, j = Tuple(I)
+            if j > i
+              @inbounds _A[j,i] = _A[i,j]
+            end
         end
-        return
-      end
-  elseif uplo == 'L' && conjugate
-      gpu_call(A) do ctx, _A
-        I = @cartesianidx _A
-        i, j = Tuple(I)
-        if j > i
-          @inbounds _A[i,j] = conj(_A[j,i])
+        U_noconj!(get_backend(A))(A, ndrange = size(A))
+    elseif uplo == 'L' && conjugate
+        @kernel function L_conj!(_A)
+            I = @index(Global, Cartesian)
+            i, j = Tuple(I)
+            if j > i
+              @inbounds _A[i,j] = conj(_A[j,i])
+            end
         end
-        return
-      end
-  elseif uplo == 'L' && !conjugate
-      gpu_call(A) do ctx, _A
-        I = @cartesianidx _A
-        i, j = Tuple(I)
-        if j > i
-          @inbounds _A[i,j] = _A[j,i]
+        L_conj!(get_backend(A))(A, ndrange = size(A))
+    elseif uplo == 'L' && !conjugate
+        @kernel function L_noconj!(_A)
+            I = @index(Global, Cartesian)
+            i, j = Tuple(I)
+            if j > i
+              @inbounds _A[i,j] = _A[j,i]
+            end
         end
-        return
-      end
-  else
-      throw(ArgumentError("uplo argument must be 'U' (upper) or 'L' (lower), got $uplo"))
-  end
-  A
+        L_noconj!(get_backend(A))(A, ndrange = size(A))
+    else
+        throw(ArgumentError("uplo argument must be 'U' (upper) or 'L' (lower), got $uplo"))
+    end
+    A
 end
 
 ## copy a triangular part of a matrix to another matrix
@@ -113,23 +113,23 @@ if isdefined(LinearAlgebra, :copytrito!)
         m1,n1 = size(B)
         (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
         if uplo == 'U'
-            gpu_call(A, B) do ctx, _A, _B
-                I = @cartesianidx _A
+            @kernel function U_kernel!(_A, _B)
+                I = @index(Global, Cartesian)
                 i, j = Tuple(I)
                 if j >= i
                     @inbounds _B[i,j] = _A[i,j]
                 end
-                return
             end
+            U_kernel!(get_backend(B))(A, B, ndrange = size(A))
         else  # uplo == 'L'
-            gpu_call(A, B) do ctx, _A, _B
-                I = @cartesianidx _A
+            @kernel function L_kernel!(_A, _B)
+                I = @index(Global, Cartesian)
                 i, j = Tuple(I)
                 if j <= i
                     @inbounds _B[i,j] = _A[i,j]
                 end
-                return
             end
+            L_kernel!(get_backend(A))(A, B, ndrange = size(A))
         end
         return B
     end
@@ -149,26 +149,26 @@ for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriang
 end
 
 function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
-  gpu_call(A, d; name="tril!") do ctx, _A, _d
-    I = @cartesianidx _A
+  @kernel function tril_kernel!(_A, _d)
+    I = @index(Global, Cartesian)
     i, j = Tuple(I)
     if i < j - _d
       @inbounds _A[i, j] = zero(T)
     end
-    return
   end
+  tril_kernel!(get_backend(A))(A, d, ndrange = size(A))
   return A
 end
 
 function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
-  gpu_call(A, d; name="triu!") do ctx, _A, _d
-    I = @cartesianidx _A
+  @kernel function triu_kernel!(_A, _d)
+    I = @index(Global, Cartesian)
     i, j = Tuple(I)
     if j < i + _d
       @inbounds _A[i, j] = zero(T)
     end
-    return
   end
+  triu_kernel!(get_backend(A))(A, d, ndrange = size(A))
   return A
 end
 
@@ -330,9 +330,9 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
         return fill!(C, zero(R))
     end
 
-    gpu_call(C, A, B; name="matmatmul!") do ctx, C, A, B
-        idx = @linearidx C
+    @kernel function matmatmul_kernel!(C, A, B)
         assume.(size(C) .> 0)
+        idx = @index(Global, Linear)
         i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
 
         @inbounds if i <= size(A,1) && j <= size(B,2)
@@ -343,10 +343,8 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
             end
             C[i,j] = add(Cij, C[i,j])
         end
-
-        return
     end
-
+    matmatmul_kernel!(get_backend(C))(C, A, B, ndrange = size(C))
     C
 end
 
@@ -382,8 +380,8 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun
     upper = tfun === identity ? uploc == 'U' :  uploc != 'U'
     unit  = isunitc == 'U'
 
-    function trimatmul(ctx, C, A, B)
-        idx = @linearidx C
+    @kernel function trimatmul(C, A, B)
+        idx = @index(Global, Linear)
         assume.(size(C) .> 0)
         i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
         l, m, n = size(A, 1), size(B, 1), size(B, 2)
@@ -397,12 +395,10 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun
             end
             C[i,j] += Cij
         end
-
-        return
     end
 
-    function trimatmul_t(ctx, C, A, B)
-        idx = @linearidx C
+    @kernel function trimatmul_t(C, A, B)
+        idx = @index(Global, Linear)
         assume.(size(C) .> 0)
         i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
         l, m, n = size(A, 1), size(B, 1), size(B, 2)
@@ -416,12 +412,10 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun
             end
             C[i,j] += Cij
         end
-
-        return
     end
 
-    function trimatmul_a(ctx, C, A, B)
-        idx = @linearidx C
+    @kernel function trimatmul_a(C, A, B)
+        idx = @index(Global, Linear)
         assume.(size(C) .> 0)
         i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
         l, m, n = size(A, 1), size(B, 1), size(B, 2)
@@ -435,16 +429,14 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun
             end
             C[i,j] += Cij
         end
-
-        return
     end
 
     if tfun === identity
-        gpu_call(trimatmul, C, A, B; name="trimatmul")
+        trimatmul(get_backend(C))(C, A, B, ndrange = length(C))
     elseif tfun == transpose
-        gpu_call(trimatmul_t, C, A, B; name="trimatmul_t")
+        trimatmul_t(get_backend(C))(C, A, B, ndrange = length(C))
     elseif tfun === adjoint
-        gpu_call(trimatmul_a, C, A, B; name="trimatmul_a")
+        trimatmul_a(get_backend(C))(C, A, B, ndrange = length(C))
     else
         error("Not supported")
     end
@@ -466,8 +458,8 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun
     upper = tfun === identity ? uploc == 'U' :  uploc != 'U'
     unit  = isunitc == 'U'
 
-    function mattrimul(ctx, C, A, B)
-        idx = @linearidx C
+    @kernel function mattrimul(C, A, B)
+        idx = @index(Global, Linear)
         assume.(size(C) .> 0)
         i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
         l, m, n = size(A, 1), size(B, 1), size(B, 2)
@@ -481,12 +473,10 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun
             end
             C[i,j] += Cij
         end
-
-        return
     end
 
-    function mattrimul_t(ctx, C, A, B)
-        idx = @linearidx C
+    @kernel function mattrimul_t(C, A, B)
+        idx = @index(Global, Linear)
         assume.(size(C) .> 0)
         i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
         l, m, n = size(A, 1), size(B, 1), size(B, 2)
@@ -500,12 +490,10 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun
             end
             C[i,j] += Cij
         end
-
-        return
     end
 
-    function mattrimul_a(ctx, C, A, B)
-        idx = @linearidx C
+    @kernel function mattrimul_a(C, A, B)
+        idx = @index(Global, Linear)
         assume.(size(C) .> 0)
         i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
         l, m, n = size(A, 1), size(B, 1), size(B, 2)
@@ -519,16 +507,14 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun
             end
             C[i,j] += Cij
         end
-
-        return
     end
 
     if tfun === identity
-        gpu_call(mattrimul, C, A, B; name="mattrimul")
+        mattrimul(get_backend(C))(C, A, B, ndrange = length(C))
     elseif tfun == transpose
-        gpu_call(mattrimul_t, C, A, B; name="mattrimul_t")
+        mattrimul_t(get_backend(C))(C, A, B, ndrange = length(C))
     elseif tfun === adjoint
-        gpu_call(mattrimul_a, C, A, B; name="mattrimul_a")
+        mattrimul_a(get_backend(C))(C, A, B, ndrange = length(C))
     else
         error("Not supported")
     end
@@ -544,22 +530,22 @@ function LinearAlgebra.generic_mattrimul!(C::AbstractGPUMatrix, uploc, isunitc,
 end
 
 function generic_rmul!(X::AbstractArray, s::Number)
-    gpu_call(X, s; name="rmul!") do ctx, X, s
-        i = @linearidx X
+    @kernel function rmul_kernel!(X, s)
+        i = @index(Global, Linear)
         @inbounds X[i] *= s
-        return
     end
+    rmul_kernel!(get_backend(X))(X, s, ndrange = size(X))
     return X
 end
 
 LinearAlgebra.rmul!(A::AbstractGPUArray, b::Number) = generic_rmul!(A, b)
 
 function generic_lmul!(s::Number, X::AbstractArray)
-    gpu_call(X, s; name="lmul!") do ctx, X, s
-        i = @linearidx X
+    @kernel function lmul_kernel!(X, s)
+        i = @index(Global, Linear)
         @inbounds X[i] = s*X[i]
-        return
     end
+    lmul_kernel!(get_backend(X))(X, s, ndrange = size(X))
     return X
 end
 
@@ -601,15 +587,16 @@ function _permutedims!(::Type{IT}, dest::AbstractGPUArray,
     dest_strides = ntuple(k->k==1 ? 1 : prod(i->size(dest, i), 1:k-1), N)
     dest_strides_perm = ntuple(i->IT(dest_strides[findfirst(==(i), perm)]), N)
     size_src = IT.(size(src))
-    function permutedims_kernel(ctx, dest, src, size_src, dest_strides_perm)
-        SLI = @linearidx dest
+    @kernel function permutedims_kernel!(dest, src, size_src, dest_strides_perm)
+        SLI = @index(Global, Linear)
         assume(0 < SLI <= typemax(IT))
         LI = IT(SLI)
         dest_index = permute_linearindex(size_src, LI, dest_strides_perm)
         @inbounds dest[dest_index] = src[LI]
-        return
     end
-    gpu_call(permutedims_kernel, vec(dest), vec(src), size_src, dest_strides_perm)
+    permutedims_kernel!(get_backend(dest))(vec(dest), vec(src), size_src,
+                                           dest_strides_perm,
+                                           ndrange = size(dest))
     return dest
 end
 
@@ -686,28 +673,28 @@ end
 ## rotate
 
 function LinearAlgebra.rotate!(x::AbstractGPUArray, y::AbstractGPUArray, c::Number, s::Number)
-    gpu_call(x, y, c, s; name="rotate!") do ctx, x, y, c, s
-        i = @linearidx x
+    @kernel function rotate_kernel!(x, y, c, s)
+        i = @index(Global, Linear)
         @inbounds xi = x[i]
         @inbounds yi = y[i]
         @inbounds x[i] =       c  * xi + s * yi
         @inbounds y[i] = -conj(s) * xi + c * yi
-        return
     end
+    rotate_kernel!(get_backend(x))(x, y, c, s, ndrange = size(x))
     return x, y
 end
 
 ## reflect
 
 function LinearAlgebra.reflect!(x::AbstractGPUArray, y::AbstractGPUArray, c::Number, s::Number)
-    gpu_call(x, y, c, s; name="reflect!") do ctx, x, y, c, s
-        i = @linearidx x
+    @kernel function  reflect_kernel!(x, y, c, s)
+        i = @index(Global, Linear)
         @inbounds xi = x[i]
         @inbounds yi = y[i]
         @inbounds x[i] =      c  * xi + s * yi
         @inbounds y[i] = conj(s) * xi - c * yi
-        return
     end
+    reflect_kernel!(get_backend(x))(x, y, c, s, ndrange = size(x))
     return x, y
 end
 
diff --git a/src/host/math.jl b/src/host/math.jl
index cf455d31..f96fb8ed 100644
--- a/src/host/math.jl
+++ b/src/host/math.jl
@@ -1,10 +1,10 @@
 # Base mathematical operations
 
 function Base.clamp!(A::AnyGPUArray, low, high)
-    gpu_call(A, low, high) do ctx, A, low, high
-        I = @linearidx A
+    @kernel function clamp_kernel!(A, low, high)
+        I = @index(Global, Cartesian)
         A[I] = clamp(A[I], low, high)
-        return
     end
+    clamp_kernel!(get_backend(A))(A, low, high, ndrange = size(A))
     return A
 end
diff --git a/src/host/random.jl b/src/host/random.jl
index b7e5dc74..2112a8ed 100644
--- a/src/host/random.jl
+++ b/src/host/random.jl
@@ -30,15 +30,13 @@ function next_rand(state::NTuple{4, T}) where {T <: Unsigned}
     return state, tmp
 end
 
-function gpu_rand(::Type{T}, ctx::AbstractKernelContext, randstate::AbstractVector{NTuple{4, UInt32}}) where T
-    threadid = GPUArrays.threadidx(ctx)
+function gpu_rand(::Type{T}, threadid, randstate::AbstractVector{NTuple{4, UInt32}}) where T
     stateful_rand = next_rand(randstate[threadid])
     randstate[threadid] = stateful_rand[1]
     return make_rand_num(T, stateful_rand[2])
 end
 
-function gpu_rand(::Type{T}, ctx::AbstractKernelContext, randstate::AbstractVector{NTuple{4, UInt32}}) where T <: Integer
-    threadid = GPUArrays.threadidx(ctx)
+function gpu_rand(::Type{T}, threadid, randstate::AbstractVector{NTuple{4, UInt32}}) where T <: Integer
     result = zero(T)
     if sizeof(T) >= 4
         for _ in 1:sizeof(T) >> 2
@@ -55,9 +53,9 @@ end
 
 # support for complex numbers
 
-function gpu_rand(::Type{Complex{T}}, ctx::AbstractKernelContext, randstate::AbstractVector{NTuple{4, UInt32}}) where T
-    re = gpu_rand(T, ctx, randstate)
-    im = gpu_rand(T, ctx, randstate)
+function gpu_rand(::Type{Complex{T}}, threadid, randstate::AbstractVector{NTuple{4, UInt32}}) where T
+    re = gpu_rand(T, threadid, randstate)
+    im = gpu_rand(T, threadid, randstate)
     return complex(re, im)
 end
 
@@ -85,29 +83,31 @@ end
 
 function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
     isempty(A) && return A
-    gpu_call(A, rng.state) do ctx, a, randstates
-        idx = linear_index(ctx)
-        idx > length(a) && return
-        @inbounds a[idx] = gpu_rand(T, ctx, randstates)
-        return
+    @kernel function rand!(a, randstate)
+        idx = @index(Global, Linear)
+        @inbounds a[idx] = gpu_rand(T, ((idx-1)%length(randstate)+1), randstate)
     end
+    rand!(get_backend(A))(A, rng.state, ndrange = size(A))
     A
 end
 
 function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
     isempty(A) && return A
     threads = (length(A) - 1) ÷ 2 + 1
-    gpu_call(A, rng.state; elements = threads) do ctx, a, randstates
-        idx = 2*(linear_index(ctx) - 1) + 1
-        U1 = gpu_rand(T, ctx, randstates)
-        U2 = gpu_rand(T, ctx, randstates)
+    @kernel function randn!(a, randstates)
+        i = @index(Global, Linear)  # may exceed length(randstates); wrapped below as in rand!
+        idx = 2*(i - 1) + 1
+        U1 = gpu_rand(T, (i-1)%length(randstates)+1, randstates)
+        U2 = gpu_rand(T, (i-1)%length(randstates)+1, randstates)
         Z0 = sqrt(T(-2.0)*log(U1))*cos(T(2pi)*U2)
         Z1 = sqrt(T(-2.0)*log(U1))*sin(T(2pi)*U2)
         @inbounds a[idx] = Z0
-        idx + 1 > length(a) && return
-        @inbounds a[idx + 1] = Z1
-        return
+        if idx + 1 <= length(a)
+            @inbounds a[idx + 1] = Z1
+        end
     end
+    kernel = randn!(get_backend(A))
+    kernel(A, rng.state; ndrange=threads)
     A
 end
 
diff --git a/src/host/uniformscaling.jl b/src/host/uniformscaling.jl
index 848eef5e..f8f8ae5a 100644
--- a/src/host/uniformscaling.jl
+++ b/src/host/uniformscaling.jl
@@ -12,20 +12,16 @@ const unittriangularwrappers = (
     (:UnitLowerTriangular, :LowerTriangular)
 )
 
-function kernel_generic(ctx, B, J, min_size)
-    lin_idx = linear_index(ctx)
-    lin_idx > min_size && return nothing
+@kernel function kernel_generic(B, J)
+    lin_idx = @index(Global, Linear)
     @inbounds diag_idx = diagind(B)[lin_idx]
     @inbounds B[diag_idx] += J
-    return nothing
 end
 
-function kernel_unittriangular(ctx, B, J, diagonal_val, min_size)
-    lin_idx = linear_index(ctx)
-    lin_idx > min_size && return nothing
+@kernel function kernel_unittriangular(B, J, diagonal_val)
+    lin_idx = @index(Global, Linear)
     @inbounds diag_idx = diagind(B)[lin_idx]
     @inbounds B[diag_idx] = diagonal_val + J
-    return nothing
 end
 
 for (t1, t2) in unittriangularwrappers
@@ -34,7 +30,7 @@ for (t1, t2) in unittriangularwrappers
             B = similar(parent(A), typeof(oneunit(T) + J))
             copyto!(B, parent(A))
             min_size = minimum(size(B))
-            gpu_call(kernel_unittriangular, B, J, one(eltype(B)), min_size; elements=min_size)
+            kernel_unittriangular(get_backend(B))(B, J, one(eltype(B)); ndrange=min_size)
             return $t2(B)
         end
 
@@ -42,7 +38,7 @@ for (t1, t2) in unittriangularwrappers
             B = similar(parent(A), typeof(J - oneunit(T)))
             B .= .- parent(A)
             min_size = minimum(size(B))
-            gpu_call(kernel_unittriangular, B, J, -one(eltype(B)), min_size; elements=min_size)
+            kernel_unittriangular(get_backend(B))(B, J, -one(eltype(B)); ndrange=min_size)
             return $t2(B)
         end
     end
@@ -54,7 +50,7 @@ for t in genericwrappers
             B = similar(parent(A), typeof(oneunit(T) + J))
             copyto!(B, parent(A))
             min_size = minimum(size(B))
-            gpu_call(kernel_generic, B, J, min_size; elements=min_size)
+            kernel_generic(get_backend(B))(B, J; ndrange=min_size)
             return $t(B)
         end
 
@@ -62,7 +58,7 @@ for t in genericwrappers
             B = similar(parent(A), typeof(J - oneunit(T)))
             B .= .- parent(A)
             min_size = minimum(size(B))
-            gpu_call(kernel_generic, B, J, min_size; elements=min_size)
+            kernel_generic(get_backend(B))(B, J; ndrange=min_size)
             return $t(B)
         end
     end
@@ -73,7 +69,7 @@ function (+)(A::Hermitian{T,<:AbstractGPUMatrix}, J::UniformScaling{<:Complex})
     B = similar(parent(A), typeof(oneunit(T) + J))
     copyto!(B, parent(A))
     min_size = minimum(size(B))
-    gpu_call(kernel_generic, B, J, min_size; elements=min_size)
+    kernel_generic(get_backend(B))(B, J; ndrange=min_size)
     return B
 end
 
@@ -81,7 +77,7 @@ function (-)(J::UniformScaling{<:Complex}, A::Hermitian{T,<:AbstractGPUMatrix})
     B = similar(parent(A), typeof(J - oneunit(T)))
     B .= .-parent(A)
     min_size = minimum(size(B))
-    gpu_call(kernel_generic, B, J, min_size; elements=min_size)
+    kernel_generic(get_backend(B))(B, J; ndrange=min_size)
     return B
 end
 
@@ -90,7 +86,7 @@ function (+)(A::AbstractGPUMatrix{T}, J::UniformScaling) where T
     B = similar(A, typeof(oneunit(T) + J))
     copyto!(B, A)
     min_size = minimum(size(B))
-    gpu_call(kernel_generic, B, J, min_size; elements=min_size)
+    kernel_generic(get_backend(B))(B, J; ndrange=min_size)
     return B
 end
 
@@ -98,6 +94,6 @@ function (-)(J::UniformScaling, A::AbstractGPUMatrix{T}) where T
     B = similar(A, typeof(J - oneunit(T)))
     B .= .-A
     min_size = minimum(size(B))
-    gpu_call(kernel_generic, B, J, min_size; elements=min_size)
+    kernel_generic(get_backend(B))(B, J; ndrange=min_size)
     return B
 end
diff --git a/test/Project.toml b/test/Project.toml
index 76e1e22a..eb59ac76 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,6 +1,7 @@
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/test/runtests.jl b/test/runtests.jl
index 4df72b2b..2cfdb0c6 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,6 +2,22 @@ using GPUArrays, Test, Pkg
 
 include("testsuite.jl")
 
+@testset "JLArray" begin
+    # install the JLArrays subpackage in a temporary environment
+    old_project = Base.active_project()
+    Pkg.activate(; temp=true)
+    Pkg.develop(path=joinpath(dirname(@__DIR__), "lib", "JLArrays"))
+
+    using JLArrays
+
+    jl([1])
+
+    TestSuite.test(JLArray)
+
+    Pkg.activate(old_project)
+end
+
+#=
 @testset "JLArray" begin
     using JLArrays
 
@@ -9,6 +25,7 @@ include("testsuite.jl")
 
     TestSuite.test(JLArray)
 end
+=#
 
 @testset "Array" begin
     TestSuite.test(Array)
diff --git a/test/testsuite.jl b/test/testsuite.jl
index e7c14646..b939d2e9 100644
--- a/test/testsuite.jl
+++ b/test/testsuite.jl
@@ -8,6 +8,7 @@ export supported_eltypes
 
 using GPUArrays
 
+using KernelAbstractions
 using LinearAlgebra
 using Random
 using Test
@@ -85,7 +86,6 @@ macro testsuite(name, ex)
 end
 
 include("testsuite/construction.jl")
-include("testsuite/gpuinterface.jl")
 include("testsuite/indexing.jl")
 include("testsuite/base.jl")
 include("testsuite/vector.jl")
diff --git a/test/testsuite/base.jl b/test/testsuite/base.jl
index 5166654d..0e134398 100644
--- a/test/testsuite/base.jl
+++ b/test/testsuite/base.jl
@@ -1,28 +1,23 @@
-function cartesian_iter(state, res, A, Asize)
-    for i in CartesianIndices(Asize)
-        res[i] = A[i]
-    end
-    return
+@kernel function cartesian_iter(res, A)
+    i = @index(Global, Cartesian)
+    res[i] = A[i]
 end
 
-function clmap!(ctx, f, out, b)
-    i = linear_index(ctx) # get the kernel index it gets scheduled on
+@kernel function clmap!(f, out, b)
+    i = @index(Global, Linear) # get the kernel index it gets scheduled on
     out[i] = f(b[i])
-    return
 end
 
-function ntuple_test(ctx, result, ::Val{N}) where N
+@kernel function ntuple_test(result, ::Val{N}) where N
     result[1] = ntuple(Val(N)) do i
         Float32(i) * 77f0
     end
-    return
 end
 
-function ntuple_closure(ctx, result, ::Val{N}, testval) where N
+@kernel function ntuple_closure(result, ::Val{N}, testval) where N
     result[1] = ntuple(Val(N)) do i
         Float32(i) * testval
     end
-    return
 end
 
 @testsuite "base" (AT, eltypes)->begin
@@ -191,10 +186,10 @@ end
 
     AT <: AbstractGPUArray && @testset "ntuple test" begin
         result = AT(Vector{NTuple{3, Float32}}(undef, 1))
-        gpu_call(ntuple_test, result, Val(3))
+        ntuple_test(get_backend(result))(result, Val(3); ndrange = 1)
         @test Array(result)[1] == (77, 2*77, 3*77)
         x = 88f0
-        gpu_call(ntuple_closure, result, Val(3), x)
+        ntuple_closure(get_backend(result))(result, Val(3), x; ndrange = 1)
         @test Array(result)[1] == (x, 2*x, 3*x)
     end
 
@@ -202,14 +197,14 @@ end
         Ac = rand(Float32, 32, 32)
         A = AT(Ac)
         result = fill!(copy(A), 0.0f0)
-        gpu_call(cartesian_iter, result, A, size(A))
+        cartesian_iter(get_backend(A))(result, A; ndrange = size(A))
         Array(result) == Ac
     end
 
     AT <: AbstractGPUArray && @testset "Custom kernel from Julia function" begin
         x = AT(rand(Float32, 100))
         y = AT(rand(Float32, 100))
-        gpu_call(clmap!, -, x, y; target=x)
+        clmap!(get_backend(x))(-, x, y; ndrange = size(x))
         jy = Array(y)
         @test map!(-, jy, jy) ≈ Array(x)
     end
diff --git a/test/testsuite/broadcasting.jl b/test/testsuite/broadcasting.jl
index b856eb0f..81b028f3 100644
--- a/test/testsuite/broadcasting.jl
+++ b/test/testsuite/broadcasting.jl
@@ -200,8 +200,9 @@ Base.size(A::WrapArray) = size(A.data)
 # For kernal support
 Adapt.adapt_structure(to, s::WrapArray) = WrapArray(Adapt.adapt(to, s.data))
 # For broadcast support
-GPUArrays.backend(::Type{WrapArray{T,N,P}}) where {T,N,P} = GPUArrays.backend(P)
 Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P)
+KernelAbstractions.get_backend(a::WA) where WA <: WrapArray = get_backend(a.data)
+
 
 function unknown_wrapper(AT, eltypes)
     for ET in eltypes
diff --git a/test/testsuite/gpuinterface.jl b/test/testsuite/gpuinterface.jl
deleted file mode 100644
index 1455c732..00000000
--- a/test/testsuite/gpuinterface.jl
+++ /dev/null
@@ -1,47 +0,0 @@
-@testsuite "interface" (AT, eltypes)->begin
-    AT <: AbstractGPUArray || return
-
-    N = 10
-    x = AT(Vector{Int}(undef, N))
-    x .= 0
-    gpu_call(x) do ctx, x
-        x[linear_index(ctx)] = 2
-        return
-    end
-    @test all(x-> x == 2, Array(x))
-
-    gpu_call(x; elements=N) do ctx, x
-        x[linear_index(ctx)] = 2
-        return
-    end
-    @test all(x-> x == 2, Array(x))
-    gpu_call(x; threads=2, blocks=(N ÷ 2)) do ctx, x
-        x[linear_index(ctx)] = threadidx(ctx)
-        return
-    end
-    @test Array(x) == [1,2,1,2,1,2,1,2,1,2]
-
-    gpu_call(x; threads=2, blocks=(N ÷ 2)) do ctx, x
-        x[linear_index(ctx)] = blockidx(ctx)
-        return
-    end
-    @test Array(x) == [1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
-    x2 = AT([0])
-    gpu_call(x2; threads=2, blocks=(N ÷ 2), target=x) do ctx, x
-        x[1] = blockdim(ctx)
-        return
-    end
-    @test Array(x2) == [2]
-
-    gpu_call(x2; threads=2, blocks=(N ÷ 2), target=x) do ctx, x
-        x[1] = griddim(ctx)
-        return
-    end
-    @test Array(x2) == [5]
-
-    gpu_call(x2; threads=2, blocks=(N ÷ 2), target=x) do ctx, x
-        x[1] = global_size(ctx)
-        return
-    end
-    @test Array(x2) == [10]
-end

From 00c8dd4912c5d1c4c4260f0b8baecf71647d52ad Mon Sep 17 00:00:00 2001
From: James Schloss <jrs.schloss@gmail.com>
Date: Thu, 25 Jul 2024 15:25:38 +0200
Subject: [PATCH 2/4] removing heuristic

---
 src/GPUArrays.jl        |  1 -
 src/device/execution.jl | 39 ---------------------------------------
 src/host/broadcast.jl   | 26 ++++++--------------------
 3 files changed, 6 insertions(+), 60 deletions(-)
 delete mode 100644 src/device/execution.jl

diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl
index 54e9c877..eb8df141 100644
--- a/src/GPUArrays.jl
+++ b/src/GPUArrays.jl
@@ -16,7 +16,6 @@ using Reexport
 @reexport using GPUArraysCore
 
 ## executed on-device
-include("device/execution.jl")
 include("device/abstractarray.jl")
 
 using KernelAbstractions
diff --git a/src/device/execution.jl b/src/device/execution.jl
deleted file mode 100644
index 64a81dad..00000000
--- a/src/device/execution.jl
+++ /dev/null
@@ -1,39 +0,0 @@
-# kernel execution
-
-# how many threads and blocks `kernel` needs to be launched with, passing arguments `args`,
-# to fully saturate the GPU. `elements` indicates the number of elements that needs to be
-# processed, while `elements_per_threads` indicates the number of elements this kernel can
-# process (i.e. if it's a grid-stride kernel, or 1 if otherwise).
-#
-# this heuristic should be specialized for the back-end, ideally using an API for maximizing
-# the occupancy of the launch configuration (like CUDA's occupancy API).
-function launch_heuristic(backend::B, kernel, args...;
-                          elements::Int,
-                          elements_per_thread::Int) where B <: Backend
-    return (threads=256, blocks=32)
-end
-
-# determine how many threads and blocks to actually launch given upper limits.
-# returns a tuple of blocks, threads, and elements_per_thread (which is always 1
-# unless specified that the kernel can handle a number of elements per thread)
-function launch_configuration(backend::B, heuristic;
-                              elements::Int,
-                              elements_per_thread::Int) where B <: Backend
-    threads = clamp(elements, 1, heuristic.threads)
-    blocks = max(cld(elements, threads), 1)
-
-    if elements_per_thread > 1 && blocks > heuristic.blocks
-        # we want to launch more blocks than required, so prefer a grid-stride loop instead
-        ## try to stick to the number of blocks that the heuristic suggested
-        blocks = heuristic.blocks
-        nelem = cld(elements, blocks*threads)
-        ## only bump the number of blocks if we really need to
-        if nelem > elements_per_thread
-            nelem = elements_per_thread
-            blocks = cld(elements, nelem*threads)
-        end
-        (; threads, blocks, elements_per_thread=nelem)
-    else
-        (; threads, blocks, elements_per_thread=1)
-    end
-end
diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl
index e8e01084..44c4bf11 100644
--- a/src/host/broadcast.jl
+++ b/src/host/broadcast.jl
@@ -117,28 +117,14 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
     end
 
     # grid-stride kernel
-    @kernel function map_kernel(dest, bc, nelem, common_length)
-
-        j = 0
-        J = @index(Global, Linear)
-        for i in 1:nelem
-            j += 1
-            if j <= common_length
-
-                J_c = CartesianIndices(axes(bc))[(J-1)*nelem + j]
-                @inbounds dest[J_c] = bc[J_c]
-            end
-        end
+    @kernel function map_kernel(dest, bc)
+        j = @index(Global, Linear)
+        @inbounds dest[j] = bc[j]
     end
-    elements = common_length
-    elements_per_thread = typemax(Int)
+
     kernel = map_kernel(get_backend(dest))
-    heuristic = launch_heuristic(get_backend(dest), kernel, dest, bc, 1,
-                                 common_length; elements, elements_per_thread)
-    config = launch_configuration(get_backend(dest), heuristic;
-                                  elements, elements_per_thread)
-    kernel(dest, bc, config.elements_per_thread,
-           common_length; ndrange = config.threads)
+    config = KernelAbstractions.launch_config(kernel, common_length, nothing)
+    kernel(dest, bc; ndrange = config[1], workgroupsize = config[2])
 
     if eltype(dest) <: BrokenBroadcast
         throw(ArgumentError("Map operation resulting in $(eltype(eltype(dest))) is not GPU compatible"))

From 52db290176b2f38b1b021911b6afaaf231b0ce9c Mon Sep 17 00:00:00 2001
From: James Schloss <jrs.schloss@gmail.com>
Date: Mon, 16 Sep 2024 13:04:07 +0200
Subject: [PATCH 3/4] Revert "removing heuristic"

This reverts commit 0c7e26b2ea212bb2951ce9694a684f3de9814819.
---
 src/GPUArrays.jl        |  1 +
 src/device/execution.jl | 39 +++++++++++++++++++++++++++++++++++++++
 src/host/broadcast.jl   | 26 ++++++++++++++++++++------
 3 files changed, 60 insertions(+), 6 deletions(-)
 create mode 100644 src/device/execution.jl

diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl
index eb8df141..54e9c877 100644
--- a/src/GPUArrays.jl
+++ b/src/GPUArrays.jl
@@ -16,6 +16,7 @@ using Reexport
 @reexport using GPUArraysCore
 
 ## executed on-device
+include("device/execution.jl")
 include("device/abstractarray.jl")
 
 using KernelAbstractions
diff --git a/src/device/execution.jl b/src/device/execution.jl
new file mode 100644
index 00000000..64a81dad
--- /dev/null
+++ b/src/device/execution.jl
@@ -0,0 +1,39 @@
+# kernel execution
+
+# how many threads and blocks `kernel` needs to be launched with, passing arguments `args`,
+# to fully saturate the GPU. `elements` is the number of elements that need to be processed;
+# `elements_per_thread` is the number of elements a single thread of this kernel can process
+# (greater than 1 if it's a grid-stride kernel, 1 otherwise).
+#
+# this heuristic should be specialized for the back-end, ideally using an API for maximizing
+# the occupancy of the launch configuration (like CUDA's occupancy API).
+function launch_heuristic(backend::B, kernel, args...;
+                          elements::Int,
+                          elements_per_thread::Int) where B <: Backend
+    return (threads=256, blocks=32) # generic fallback: fixed values, arguments are unused
+end
+
+# determine how many threads and blocks to actually launch given upper limits.
+# returns a named tuple `(; threads, blocks, elements_per_thread)`, where the last field is
+# 1 unless `elements_per_thread > 1` was passed and the heuristic's block limit is exceeded.
+function launch_configuration(backend::B, heuristic;
+                              elements::Int,
+                              elements_per_thread::Int) where B <: Backend
+    threads = clamp(elements, 1, heuristic.threads) # clamp to the range 1..heuristic.threads
+    blocks = max(cld(elements, threads), 1) # enough blocks for one element per thread
+
+    if elements_per_thread > 1 && blocks > heuristic.blocks
+        # more blocks needed than the heuristic suggests: prefer a grid-stride loop instead
+        ## try to stick to the number of blocks that the heuristic suggested
+        blocks = heuristic.blocks
+        nelem = cld(elements, blocks*threads) # elements each thread then has to cover
+        ## only bump the number of blocks if we really need to
+        if nelem > elements_per_thread
+            nelem = elements_per_thread
+            blocks = cld(elements, nelem*threads)
+        end
+        (; threads, blocks, elements_per_thread=nelem)
+    else
+        (; threads, blocks, elements_per_thread=1)
+    end
+end
diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl
index 44c4bf11..e8e01084 100644
--- a/src/host/broadcast.jl
+++ b/src/host/broadcast.jl
@@ -117,14 +117,28 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
     end
 
     # grid-stride kernel
-    @kernel function map_kernel(dest, bc)
-        j = @index(Global, Linear)
-        @inbounds dest[j] = bc[j]
+    @kernel function map_kernel(dest, bc, nelem, common_length)
+
+        j = 0
+        J = @index(Global, Linear)
+        for i in 1:nelem
+            j += 1
+            if j <= common_length
+
+                J_c = CartesianIndices(axes(bc))[(J-1)*nelem + j]
+                @inbounds dest[J_c] = bc[J_c]
+            end
+        end
     end
-
+    elements = common_length
+    elements_per_thread = typemax(Int)
     kernel = map_kernel(get_backend(dest))
-    config = KernelAbstractions.launch_config(kernel, common_length, nothing)
-    kernel(dest, bc; ndrange = config[1], workgroupsize = config[2])
+    heuristic = launch_heuristic(get_backend(dest), kernel, dest, bc, 1,
+                                 common_length; elements, elements_per_thread)
+    config = launch_configuration(get_backend(dest), heuristic;
+                                  elements, elements_per_thread)
+    kernel(dest, bc, config.elements_per_thread,
+           common_length; ndrange = config.threads)
 
     if eltype(dest) <: BrokenBroadcast
         throw(ArgumentError("Map operation resulting in $(eltype(eltype(dest))) is not GPU compatible"))

From c2bd9f46f9412c99d499f0dd4723415a6b20e51e Mon Sep 17 00:00:00 2001
From: James Schloss <jrs.schloss@gmail.com>
Date: Mon, 16 Sep 2024 13:48:18 +0200
Subject: [PATCH 4/4] recommenting some stuff