Skip to content

Commit 2edbc71

Browse files
committed
add solvers
1 parent 67dcf41 commit 2edbc71

File tree

3 files changed

+180
-0
lines changed

3 files changed

+180
-0
lines changed

lib/mps/MPS.jl

+3
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,7 @@ include("decomposition.jl")
2424
# matrix copy
2525
include("copy.jl")
2626

27+
# solver
28+
include("solve.jl")
29+
2730
end

lib/mps/linalg.jl

+105
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,108 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
192192

193193
return B
194194
end
195+
196+
197+
function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
198+
# TODO
199+
end
200+
201+
function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
202+
M,N = size(B)
203+
dev = current_device()
204+
queue = global_queue(dev)
205+
cmdbuf = MTLCommandBuffer(queue)
206+
enqueue!(cmdbuf)
207+
208+
X = MtlMatrix{T}(undef, size(B))
209+
210+
mps_a = MPSMatrix(A)
211+
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
212+
mps_x = MPSMatrix(X)
213+
214+
solve_kernel = MPSMatrixSolveTriangular(dev, false, false, false, false, M, N, 1.0) # TODO: likely N, M is the correct order
215+
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
216+
commit!(cmdbuf)
217+
218+
return X
219+
end
220+
221+
function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
222+
M,N = size(B)
223+
dev = current_device()
224+
queue = global_queue(dev)
225+
cmdbuf = MTLCommandBuffer(queue)
226+
enqueue!(cmdbuf)
227+
228+
Bh = reshape(B, )
229+
X = MtlMatrix{T}(undef, size(B))
230+
231+
mps_a = MPSMatrix(A)
232+
mps_b = MPSMatrix(Bh) # TODO reshape to matrix if B is a vector
233+
mps_x = MPSMatrix(X)
234+
235+
solve_kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0)
236+
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
237+
commit!(cmdbuf)
238+
239+
return X
240+
end
241+
242+
function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
243+
M,N = size(B)
244+
dev = current_device()
245+
queue = global_queue(dev)
246+
cmdbuf = MTLCommandBuffer(queue)
247+
enqueue!(cmdbuf)
248+
249+
X = MtlMatrix{T}(undef, size(B))
250+
251+
mps_a = MPSMatrix(A)
252+
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
253+
mps_x = MPSMatrix(X)
254+
255+
solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)
256+
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
257+
commit!(cmdbuf)
258+
259+
return X
260+
end
261+
262+
function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
263+
M,N = size(B)
264+
dev = current_device()
265+
queue = global_queue(dev)
266+
cmdbuf = MTLCommandBuffer(queue)
267+
enqueue!(cmdbuf)
268+
269+
X = MtlMatrix{T}(undef, size(B))
270+
271+
mps_a = MPSMatrix(A)
272+
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
273+
mps_x = MPSMatrix(X)
274+
275+
solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0)
276+
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
277+
commit!(cmdbuf)
278+
279+
return X
280+
end
281+
282+
# function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
283+
# require_one_based_indexing(A, B)
284+
# m, n = size(A)
285+
# if m == n
286+
# if istril(A)
287+
# if istriu(A)
288+
# return Diagonal(A) \ B
289+
# else
290+
# return LowerTriangular(A) \ B
291+
# end
292+
# end
293+
# if istriu(A)
294+
# return UpperTriangular(A) \ B
295+
# end
296+
# return lu(A) \ B
297+
# end
298+
# return qr(A, ColumnNorm()) \ B
299+
# end

lib/mps/solve.jl

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
export MPSMatrixSolveLU
2+
3+
@objcwrapper immutable=false MPSMatrixSolveLU <: MPSMatrixBinaryKernel
4+
5+
function MPSMatrixSolveLU(device, transpose, order, numberOfRightHandSides)
6+
kernel = @objc [MPSMatrixSolveLU alloc]::id{MPSMatrixSolveLU}
7+
obj = MPSMatrixSolveLU(kernel)
8+
finalizer(release, obj)
9+
@objc [obj::id{MPSMatrixSolveLU} initWithDevice:device::id{MTLDevice}
10+
transpose:transpose::Bool
11+
order:order::NSUInteger
12+
numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveLU}
13+
return obj
14+
end
15+
16+
function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveLU, sourceMatrix, rightHandSideMatrix, pivotIndices, solutionMatrix)
17+
@objc [kernel::id{MPSMatrixSolveLU} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
18+
sourceMatrix:sourceMatrix::id{MPSMatrix}
19+
rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
20+
pivotIndices:pivotIndices::id{MPSMatrix}
21+
solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
22+
end
23+
24+
25+
export MPSMatrixSolveTriangular
26+
27+
@objcwrapper immutable=false MPSMatrixSolveTriangular <: MPSMatrixBinaryKernel
28+
29+
function MPSMatrixSolveTriangular(device, right, upper, transpose, unit, order, numberOfRightHandSides, alpha)
30+
kernel = @objc [MPSMatrixSolveTriangular alloc]::id{MPSMatrixSolveTriangular}
31+
obj = MPSMatrixSolveTriangular(kernel)
32+
finalizer(release, obj)
33+
@objc [obj::id{MPSMatrixSolveTriangular} initWithDevice:device::id{MTLDevice}
34+
right:right::Bool
35+
upper:upper::Bool
36+
transpose:transpose::Bool
37+
unit:unit::Bool
38+
order:order::NSUInteger
39+
numberOfRightHandSides:numberOfRightHandSides::NSUInteger
40+
alpha:alpha::Cdouble]::id{MPSMatrixSolveTriangular}
41+
return obj
42+
end
43+
44+
function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveTriangular, sourceMatrix, rightHandSideMatrix, solutionMatrix)
45+
@objc [kernel::id{MPSMatrixSolveTriangular} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
46+
sourceMatrix:sourceMatrix::id{MPSMatrix}
47+
rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
48+
solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
49+
end
50+
51+
52+
export MPSMatrixSolveCholesky
53+
54+
@objcwrapper immutable=false MPSMatrixSolveCholesky <: MPSMatrixBinaryKernel
55+
56+
function MPSMatrixSolveCholesky(device, upper, order, numberOfRightHandSides)
57+
kernel = @objc [MPSMatrixSolveCholesky alloc]::id{MPSMatrixSolveCholesky}
58+
obj = MPSMatrixSolveCholesky(kernel)
59+
finalizer(release, obj)
60+
@objc [obj::id{MPSMatrixSolveCholesky} initWithDevice:device::id{MTLDevice}
61+
upper:upper::Bool
62+
order:order::NSUInteger
63+
numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveCholesky}
64+
return obj
65+
end
66+
67+
function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveCholesky, sourceMatrix, rightHandSideMatrix, solutionMatrix)
68+
@objc [kernel::id{MPSMatrixSolveCholesky} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
69+
sourceMatrix:sourceMatrix::id{MPSMatrix}
70+
rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
71+
solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
72+
end

0 commit comments

Comments
 (0)