@@ -52,7 +52,7 @@ synchronize(backend)
52
52
```
53
53
"""
54
54
macro kernel (expr)
55
- __kernel (expr, #= generate_cpu=# true , #= force_inbounds=# false )
55
+ __kernel (DynamicSize (), expr, #= generate_cpu=# true , #= force_inbounds=# false )
56
56
end
57
57
58
58
"""
@@ -70,10 +70,11 @@ This allows for two different configurations:
70
70
"""
71
71
macro kernel (ex... )
72
72
if length (ex) == 1
73
- __kernel (ex[1 ], true , false )
73
+ __kernel (DynamicSize (), ex[1 ], true , false )
74
74
else
75
75
generate_cpu = true
76
76
force_inbounds = false
77
+ N = DynamicSize () # TODO parse N
77
78
for i in 1 : (length (ex) - 1 )
78
79
if ex[i] isa Expr && ex[i]. head == :(= ) &&
79
80
ex[i]. args[1 ] == :cpu && ex[i]. args[2 ] isa Bool
@@ -90,7 +91,7 @@ macro kernel(ex...)
90
91
)
91
92
end
92
93
end
93
- __kernel (ex[end ], generate_cpu, force_inbounds)
94
+ __kernel (N, ex[end ], generate_cpu, force_inbounds)
94
95
end
95
96
end
96
97
@@ -586,7 +587,7 @@ in a workgroup.
586
587
```
587
588
As well as the on-device functionality.
588
589
"""
589
- struct Kernel{Backend, N, WorkgroupSize <: _Size , NDRange <: _Size , Fun}
590
+ struct Kernel{Backend, N <: _Size , WorkgroupSize <: _Size , NDRange <: _Size , Fun}
590
591
backend:: Backend
591
592
f:: Fun
592
593
end
@@ -595,8 +596,9 @@ function Base.similar(kernel::Kernel{D, N, WS, ND}, f::F) where {D, N, WS, ND, F
595
596
Kernel {D, N, WS, ND, F} (kernel. backend, f)
596
597
end
597
598
598
- workgroupsize (:: Kernel{D, N, WorkgroupSize} ) where {D, WorkgroupSize} = WorkgroupSize
599
- ndrange (:: Kernel{D, N, WorkgroupSize, NDRange} ) where {D, WorkgroupSize, NDRange} = NDRange
599
+ workgroupsize (:: Kernel{D, N, WorkgroupSize} ) where {D, N, WorkgroupSize} = WorkgroupSize
600
+ ndrange (:: Kernel{D, N, WorkgroupSize, NDRange} ) where {D, N, WorkgroupSize, NDRange} = NDRange
601
+ ndims (:: Kernel{D, N} ) where {D, N} = N
600
602
backend (kernel:: Kernel ) = kernel. backend
601
603
602
604
"""
@@ -605,6 +607,7 @@ Partition a kernel for the given ndrange and workgroupsize.
605
607
@inline function partition (kernel, ndrange, workgroupsize)
606
608
static_ndrange = KernelAbstractions. ndrange (kernel)
607
609
static_workgroupsize = KernelAbstractions. workgroupsize (kernel)
610
+ static_ndims = KernelAbstractions. ndims (kernel)
608
611
609
612
if ndrange === nothing && static_ndrange <: DynamicSize ||
610
613
workgroupsize === nothing && static_workgroupsize <: DynamicSize
@@ -655,11 +658,16 @@ Partition a kernel for the given ndrange and workgroupsize.
655
658
workgroupsize = CartesianIndices (workgroupsize)
656
659
end
657
660
661
+ if static_ndims <: StaticSize
662
+ @assert get (static_ndims) == length (ndrange)
663
+ end
664
+
665
+ # TODO : Add static_ndims
658
666
iterspace = NDRange {length(ndrange), static_blocks, static_workgroupsize} (blocks, workgroupsize)
659
667
return iterspace, dynamic
660
668
end
661
669
662
- function construct (backend:: Backend , :: Val{N} , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , N, S <: _Size , NDRange <: _Size , XPUName}
670
+ function construct (backend:: Backend , :: N , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , N <: _Size , S <: _Size , NDRange <: _Size , XPUName}
663
671
return Kernel {Backend, N, S, NDRange, XPUName} (backend, xpu_name)
664
672
end
665
673
0 commit comments