Skip to content

Commit d7833d4

Browse files
committed
fixup N support
1 parent 47a1d8f commit d7833d4

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

src/KernelAbstractions.jl

+15-7
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ synchronize(backend)
5252
```
5353
"""
5454
macro kernel(expr)
55-
__kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false)
55+
__kernel(DynamicSize(), expr, #=generate_cpu=# true, #=force_inbounds=# false)
5656
end
5757

5858
"""
@@ -70,10 +70,11 @@ This allows for two different configurations:
7070
"""
7171
macro kernel(ex...)
7272
if length(ex) == 1
73-
__kernel(ex[1], true, false)
73+
__kernel(DynamicSize(), ex[1], true, false)
7474
else
7575
generate_cpu = true
7676
force_inbounds = false
77+
N = DynamicSize() # TODO parse N
7778
for i in 1:(length(ex) - 1)
7879
if ex[i] isa Expr && ex[i].head == :(=) &&
7980
ex[i].args[1] == :cpu && ex[i].args[2] isa Bool
@@ -90,7 +91,7 @@ macro kernel(ex...)
9091
)
9192
end
9293
end
93-
__kernel(ex[end], generate_cpu, force_inbounds)
94+
__kernel(N, ex[end], generate_cpu, force_inbounds)
9495
end
9596
end
9697

@@ -586,7 +587,7 @@ in a workgroup.
586587
```
587588
As well as the on-device functionality.
588589
"""
589-
struct Kernel{Backend, N, WorkgroupSize <: _Size, NDRange <: _Size, Fun}
590+
struct Kernel{Backend, N <: _Size, WorkgroupSize <: _Size, NDRange <: _Size, Fun}
590591
backend::Backend
591592
f::Fun
592593
end
@@ -595,8 +596,9 @@ function Base.similar(kernel::Kernel{D, N, WS, ND}, f::F) where {D, N, WS, ND, F
595596
Kernel{D, N, WS, ND, F}(kernel.backend, f)
596597
end
597598

598-
workgroupsize(::Kernel{D, N, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
599-
ndrange(::Kernel{D, N, WorkgroupSize, NDRange}) where {D, WorkgroupSize, NDRange} = NDRange
599+
workgroupsize(::Kernel{D, N, WorkgroupSize}) where {D, N, WorkgroupSize} = WorkgroupSize
600+
ndrange(::Kernel{D, N, WorkgroupSize, NDRange}) where {D, N, WorkgroupSize, NDRange} = NDRange
601+
ndims(::Kernel{D, N}) where {D, N} = N
600602
backend(kernel::Kernel) = kernel.backend
601603

602604
"""
@@ -605,6 +607,7 @@ Partition a kernel for the given ndrange and workgroupsize.
605607
@inline function partition(kernel, ndrange, workgroupsize)
606608
static_ndrange = KernelAbstractions.ndrange(kernel)
607609
static_workgroupsize = KernelAbstractions.workgroupsize(kernel)
610+
static_ndims = KernelAbstractions.ndims(kernel)
608611

609612
if ndrange === nothing && static_ndrange <: DynamicSize ||
610613
workgroupsize === nothing && static_workgroupsize <: DynamicSize
@@ -655,11 +658,16 @@ Partition a kernel for the given ndrange and workgroupsize.
655658
workgroupsize = CartesianIndices(workgroupsize)
656659
end
657660

661+
if static_ndims <: StaticSize
662+
@assert get(static_ndims) == length(ndrange)
663+
end
664+
665+
# TODO: Add static_ndims
658666
iterspace = NDRange{length(ndrange), static_blocks, static_workgroupsize}(blocks, workgroupsize)
659667
return iterspace, dynamic
660668
end
661669

662-
function construct(backend::Backend, ::Val{N}, ::S, ::NDRange, xpu_name::XPUName) where {Backend <: Union{CPU, GPU}, N, S <: _Size, NDRange <: _Size, XPUName}
670+
function construct(backend::Backend, ::N, ::S, ::NDRange, xpu_name::XPUName) where {Backend <: Union{CPU, GPU}, N <: _Size, S <: _Size, NDRange <: _Size, XPUName}
663671
return Kernel{Backend, N, S, NDRange, XPUName}(backend, xpu_name)
664672
end
665673

src/macros.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function find_return(stmt)
1010
end
1111

1212
# XXX: Proper errors
13-
function __kernel(expr, generate_cpu = true, force_inbounds = false)
13+
function __kernel(N, expr, generate_cpu = true, force_inbounds = false)
1414
def = splitdef(expr)
1515
name = def[:name]
1616
args = def[:args]
@@ -57,10 +57,10 @@ function __kernel(expr, generate_cpu = true, force_inbounds = false)
5757
$name(dev, size, range) = $name(dev, $StaticSize(size), $StaticSize(range))
5858
function $name(dev::Dev, sz::S, range::NDRange) where {Dev, S <: $_Size, NDRange <: $_Size}
5959
if $isgpu(dev)
60-
return $construct(dev, sz, range, $gpu_name)
60+
return $construct(dev, $(N), sz, range, $gpu_name)
6161
else
6262
if $generate_cpu
63-
return $construct(dev, sz, range, $cpu_name)
63+
return $construct(dev, $(N), sz, range, $cpu_name)
6464
else
6565
error("This kernel is unavailable for backend CPU")
6666
end

test/test.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ identity(x) = x
1010
function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; skip_tests = Set{String}())
1111
@conditional_testset "partition" skip_tests begin
1212
backend = Backend()
13-
let kernel = KernelAbstractions.Kernel{typeof(backend), StaticSize{(64,)}, DynamicSize, typeof(identity)}(backend, identity)
13+
let kernel = KernelAbstractions.Kernel{typeof(backend), DynamicSize, StaticSize{(64,)}, DynamicSize, typeof(identity)}(backend, identity)
1414
iterspace, dynamic = KernelAbstractions.partition(kernel, (128,), nothing)
1515
@test length(blocks(iterspace)) == 2
1616
@test dynamic isa NoDynamicCheck
@@ -26,7 +26,7 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
2626
@test_throws ErrorException KernelAbstractions.partition(kernel, (129,), (65,))
2727
@test KernelAbstractions.backend(kernel) == backend
2828
end
29-
let kernel = KernelAbstractions.Kernel{typeof(backend), StaticSize{(64,)}, StaticSize{(128,)}, typeof(identity)}(backend, identity)
29+
let kernel = KernelAbstractions.Kernel{typeof(backend), DynamicSize, StaticSize{(64,)}, StaticSize{(128,)}, typeof(identity)}(backend, identity)
3030
iterspace, dynamic = KernelAbstractions.partition(kernel, (128,), nothing)
3131
@test length(blocks(iterspace)) == 2
3232
@test dynamic isa NoDynamicCheck

0 commit comments

Comments
 (0)