Skip to content

Commit 367c78b

Browse files
committed
add N for hardware indices
1 parent eebd4b8 commit 367c78b

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/KernelAbstractions.jl

+11-9
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ using StaticArrays
1717
using Adapt
1818

1919
"""
20-
@kernel function f(args) end
20+
@kernel [N] function f(args) end
2121
2222
Takes a function definition and generates a [`Kernel`](@ref) constructor from it.
2323
The enclosed function is allowed to contain kernel language constructs.
2424
In order to call it the kernel has first to be specialized on the backend
2525
and then invoked on the arguments.
2626
27+
The optional `N` parameter can be used to fix the number of dimensions used for the ndrange.
28+
2729
# Kernel language
2830
2931
- [`@Const`](@ref)
@@ -55,7 +57,7 @@ macro kernel(expr)
5557
end
5658

5759
"""
58-
@kernel config function f(args) end
60+
@kernel [N] config function f(args) end
5961
6062
This allows for two different configurations:
6163
@@ -585,17 +587,17 @@ in a workgroup.
585587
```
586588
As well as the on-device functionality.
587589
"""
588-
struct Kernel{Backend, WorkgroupSize <: _Size, NDRange <: _Size, Fun}
590+
struct Kernel{Backend, N, WorkgroupSize <: _Size, NDRange <: _Size, Fun}
589591
backend::Backend
590592
f::Fun
591593
end
592594

593-
function Base.similar(kernel::Kernel{D, WS, ND}, f::F) where {D, WS, ND, F}
594-
Kernel{D, WS, ND, F}(kernel.backend, f)
595+
function Base.similar(kernel::Kernel{D, N, WS, ND}, f::F) where {D, N, WS, ND, F}
596+
Kernel{D, N, WS, ND, F}(kernel.backend, f)
595597
end
596598

597-
workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
598-
ndrange(::Kernel{D, WorkgroupSize, NDRange}) where {D, WorkgroupSize, NDRange} = NDRange
599+
workgroupsize(::Kernel{D, N, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
600+
ndrange(::Kernel{D, N, WorkgroupSize, NDRange}) where {D, WorkgroupSize, NDRange} = NDRange
599601
backend(kernel::Kernel) = kernel.backend
600602

601603
"""
@@ -658,8 +660,8 @@ Partition a kernel for the given ndrange and workgroupsize.
658660
return iterspace, dynamic
659661
end
660662

661-
function construct(backend::Backend, ::S, ::NDRange, xpu_name::XPUName) where {Backend <: Union{CPU, GPU}, S <: _Size, NDRange <: _Size, XPUName}
662-
return Kernel{Backend, S, NDRange, XPUName}(backend, xpu_name)
663+
function construct(backend::Backend, ::Val{N}, ::S, ::NDRange, xpu_name::XPUName) where {Backend <: Union{CPU, GPU}, N, S <: _Size, NDRange <: _Size, XPUName}
664+
return Kernel{Backend, N, S, NDRange, XPUName}(backend, xpu_name)
663665
end
664666

665667
###

0 commit comments

Comments
 (0)