Skip to content

Commit 67dcf41

Browse files
committed
Move MPSKernels into a dedicated file
1 parent 41c6018 commit 67dcf41

File tree

4 files changed

+45
-36
lines changed

4 files changed

+45
-36
lines changed

lib/mps/MPS.jl

+3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import GPUArrays
99

1010
is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)
1111

12+
# MPS kernel base clases
13+
include("kernel.jl")
14+
1215
# high-level wrappers
1316
include("matrix.jl")
1417

lib/mps/decomposition.jl

-12
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,6 @@
77
end
88

99

10-
export MPSMatrixUnaryKernel
11-
12-
@objcwrapper immutable=false MPSMatrixUnaryKernel <: MPSKernel
13-
14-
@objcproperties MPSMatrixUnaryKernel begin
15-
@autoproperty sourceMatrixOrigin::id{MTLOrigin} setter=setSourceMatrixOrigin
16-
@autoproperty resultMatrixOrigin::id{MTLOrigin} setter=setResultMatrixOrigin
17-
@autoproperty batchStart::NSUInteger setter=setBatchStart
18-
@autoproperty batchSize::NSUInteger setter=setBatchSize
19-
end
20-
21-
2210
export MPSMatrixDecompositionLU
2311

2412
@objcwrapper immutable=false MPSMatrixDecompositionLU <: MPSMatrixUnaryKernel

lib/mps/kernel.jl

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#
2+
# kernels
3+
#
4+
5+
@cenum MPSKernelOptions::NSUInteger begin
6+
MPSKernelOptionsNone = 0
7+
MPSKernelOptionsSkipAPIValidation = 1 << 0
8+
MPSKernelOptionsAllowReducedPrecision = 1 << 1
9+
MPSKernelOptionsDisableInternalTiling = 1 << 2
10+
MPSKernelOptionsInsertDebugGroups = 1 << 3
11+
MPSKernelOptionsVerbose = 1 << 4
12+
end
13+
14+
15+
@objcwrapper MPSKernel <: NSObject
16+
17+
@objcproperties MPSKernel begin
18+
@autoproperty device::id{MTLDevice}
19+
@autoproperty label::id{NSString} setter=setLabel
20+
@autoproperty options::MPSKernelOptions setter=setOptions
21+
end
22+
23+
24+
@objcwrapper immutable=false MPSMatrixUnaryKernel <: MPSKernel
25+
26+
@objcproperties MPSMatrixUnaryKernel begin
27+
@autoproperty sourceMatrixOrigin::id{MTLOrigin} setter=setSourceMatrixOrigin
28+
@autoproperty resultMatrixOrigin::id{MTLOrigin} setter=setResultMatrixOrigin
29+
@autoproperty batchStart::NSUInteger setter=setBatchStart
30+
@autoproperty batchSize::NSUInteger setter=setBatchSize
31+
end
32+
33+
34+
@objcwrapper immutable=false MPSMatrixBinaryKernel <: MPSKernel
35+
36+
@objcproperties MPSMatrixUnaryKernel begin
37+
@autoproperty primarySourceMatrixOrigin::id{MTLOrigin} setter=setPrimarySourceMatrixOrigin
38+
@autoproperty secondarySourceMatrixOrigin::id{MTLOrigin} setter=setSecondarySourceMatrixOrigin
39+
@autoproperty resultMatrixOrigin::id{MTLOrigin} setter=setResultMatrixOrigin
40+
@autoproperty batchStart::NSUInteger setter=setBatchStart
41+
@autoproperty batchSize::NSUInteger setter=setBatchSize
42+
end

lib/mps/matrix.jl

-24
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,6 @@ end
1212
## bitwise operations lose type information, so allow conversions
1313
Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
1414

15-
@cenum MPSKernelOptions::NSUInteger begin
16-
MPSKernelOptionsNone = 0
17-
MPSKernelOptionsSkipAPIValidation = 1 << 0
18-
MPSKernelOptionsAllowReducedPrecision = 1 << 1
19-
MPSKernelOptionsDisableInternalTiling = 1 << 2
20-
MPSKernelOptionsInsertDebugGroups = 1 << 3
21-
MPSKernelOptionsVerbose = 1 << 4
22-
end
23-
24-
2515
#
2616
# matrix descriptor
2717
#
@@ -87,20 +77,6 @@ function MPSMatrix(arr::MtlMatrix{T}) where T
8777
return obj
8878
end
8979

90-
91-
#
92-
# kernels
93-
#
94-
95-
@objcwrapper MPSKernel <: NSObject
96-
97-
@objcproperties MPSKernel begin
98-
@autoproperty device::id{MTLDevice}
99-
@autoproperty label::id{NSString} setter=setLabel
100-
@autoproperty options::MPSKernelOptions setter=setOptions
101-
end
102-
103-
10480
#
10581
# matrix multiplication
10682
#

0 commit comments

Comments
 (0)