@@ -17,13 +17,15 @@ using StaticArrays
17
17
using Adapt
18
18
19
19
"""
20
- @kernel function f(args) end
20
+ @kernel [N] function f(args) end
21
21
22
22
Takes a function definition and generates a [`Kernel`](@ref) constructor from it.
23
23
The enclosed function is allowed to contain kernel language constructs.
24
24
In order to call it, the kernel first has to be specialized on the backend
25
25
and then invoked on the arguments.
26
26
27
+ The optional `N` parameter can be used to fix the number of dimensions used for the ndrange.
28
+
27
29
# Kernel language
28
30
29
31
- [`@Const`](@ref)
@@ -55,7 +57,7 @@ macro kernel(expr)
55
57
end
56
58
57
59
"""
58
- @kernel config function f(args) end
60
+ @kernel [N] config function f(args) end
59
61
60
62
This allows for two different configurations:
61
63
@@ -585,17 +587,17 @@ in a workgroup.
585
587
```
586
588
As well as the on-device functionality.
587
589
"""
588
- struct Kernel{Backend, WorkgroupSize <: _Size , NDRange <: _Size , Fun}
590
+ struct Kernel{Backend, N, WorkgroupSize <: _Size , NDRange <: _Size , Fun}
589
591
backend:: Backend
590
592
f:: Fun
591
593
end
592
594
593
- function Base. similar (kernel:: Kernel{D, WS, ND} , f:: F ) where {D, WS, ND, F}
594
- Kernel {D, WS, ND, F} (kernel. backend, f)
595
+ function Base. similar (kernel:: Kernel{D, N, WS, ND} , f:: F ) where {D, N , WS, ND, F}
596
+ Kernel {D, N, WS, ND, F} (kernel. backend, f)
595
597
end
596
598
597
- workgroupsize (:: Kernel{D, WorkgroupSize} ) where {D, WorkgroupSize} = WorkgroupSize
598
- ndrange (:: Kernel{D, WorkgroupSize, NDRange} ) where {D, WorkgroupSize, NDRange} = NDRange
599
+ workgroupsize(::Kernel{D, N, WorkgroupSize}) where {D, N, WorkgroupSize} = WorkgroupSize  # N must be bound in `where`; otherwise it is resolved as a global
600
+ ndrange(::Kernel{D, N, WorkgroupSize, NDRange}) where {D, N, WorkgroupSize, NDRange} = NDRange  # N must be bound in `where`; otherwise it is resolved as a global
599
601
backend (kernel:: Kernel ) = kernel. backend # Accessor: return the `backend` field stored in `kernel`.
600
602
601
603
"""
@@ -658,8 +660,8 @@ Partition a kernel for the given ndrange and workgroupsize.
658
660
return iterspace, dynamic
659
661
end
660
662
661
- function construct (backend:: Backend , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , S <: _Size , NDRange <: _Size , XPUName}
662
- return Kernel {Backend, S, NDRange, XPUName} (backend, xpu_name)
663
+ function construct (backend:: Backend , :: Val{N} , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , N , S <: _Size , NDRange <: _Size , XPUName}
664
+ return Kernel {Backend, N, S, NDRange, XPUName} (backend, xpu_name)
663
665
end
664
666
665
667
# ##
0 commit comments