Skip to content

Commit 80af119

Browse files
committed
fix other solvers
1 parent ef68e4d commit 80af119

File tree

2 files changed

+133
-61
lines changed

2 files changed

+133
-61
lines changed

lib/mps/linalg.jl

+98-61
Original file line numberDiff line numberDiff line change
@@ -192,110 +192,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
192192

193193
commit!(cmdbuf)
194194

195+
wait_completed(cmdbuf)
196+
195197
return B
196198
end
197199

198200

201+
function LinearAlgebra.:(\)(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
202+
C = deepcopy(B)
203+
LinearAlgebra.ldiv!(A, C)
204+
return C
205+
end
206+
207+
199208
function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
200-
orig = size(B)
201-
M,N = size(B)[1], ndims(B) > 1 ? size(B)[2] : 1
209+
M,N = size(B,1), size(B,2)
202210
dev = current_device()
203211
queue = global_queue(dev)
204212

205-
B = reshape(B, (N,M))
213+
At = similar(A.factors)
214+
Bt = similar(B, (N,M))
206215
P = reshape((A.ipiv .- UInt32(1)), (1,M))
207-
X = similar(B)
216+
X = similar(B, (N,M))
208217

209-
mps_a = MPSMatrix(A.factors)
210-
mps_b = MPSMatrix(B)
218+
transpose!(At, A.factors)
219+
transpose!(Bt, B)
220+
221+
mps_a = MPSMatrix(At)
222+
mps_b = MPSMatrix(Bt)
211223
mps_p = MPSMatrix(P)
212224
mps_x = MPSMatrix(X)
213225

214226
MTLCommandBuffer(queue) do cmdbuf
215-
kernel = MPSMatrixSolveLU(dev, true, M, N)
227+
kernel = MPSMatrixSolveLU(dev, false, M, N)
216228
encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
217229
end
218230

219-
B .= X
220-
B = reshape(B, orig)
231+
transpose!(B, X)
232+
return B
221233
end
222234

223-
function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
224-
M,N = size(B)
235+
236+
function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
237+
M,N = size(B,1), size(B,2)
225238
dev = current_device()
226239
queue = global_queue(dev)
227-
cmdbuf = MTLCommandBuffer(queue)
228-
enqueue!(cmdbuf)
229240

230-
Bh = reshape(B, )
231-
X = MtlMatrix{T}(undef, size(B))
241+
Ad = MtlMatrix(A')
242+
Br = similar(B, (M,M))
243+
X = similar(Br)
232244

233-
mps_a = MPSMatrix(A)
234-
mps_b = MPSMatrix(Bh) # TODO reshape to matrix if B is a vector
245+
transpose!(Br, B)
246+
247+
mps_a = MPSMatrix(Ad)
248+
mps_b = MPSMatrix(Br)
235249
mps_x = MPSMatrix(X)
236250

237-
solve_kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0)
238-
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
239-
commit!(cmdbuf)
251+
buf = MTLCommandBuffer(queue) do cmdbuf
252+
kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, N, M, 1.0)
253+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
254+
end
240255

241-
return X
256+
wait_completed(buf)
257+
258+
copy!(B, X)
259+
return B
242260
end
243261

244-
function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
245-
M,N = size(B)
262+
263+
function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
264+
M,N = size(B,1), size(B,2)
246265
dev = current_device()
247266
queue = global_queue(dev)
248-
cmdbuf = MTLCommandBuffer(queue)
249-
enqueue!(cmdbuf)
250267

251-
X = MtlMatrix{T}(undef, size(B))
268+
Ad = MtlMatrix(A)
269+
Br = reshape(B, (M,N))
270+
X = similar(Br)
252271

253-
mps_a = MPSMatrix(A)
254-
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
272+
mps_a = MPSMatrix(Ad)
273+
mps_b = MPSMatrix(Br)
255274
mps_x = MPSMatrix(X)
256275

257-
solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)
258-
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
259-
commit!(cmdbuf)
276+
277+
buf = MTLCommandBuffer(queue) do cmdbuf
278+
kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0)
279+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
280+
end
260281

261-
return X
282+
wait_completed(buf)
283+
284+
copy!(Br, X)
285+
return B
262286
end
263287

264-
function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
265-
M,N = size(B)
288+
289+
function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
290+
M,N = size(B,1), size(B,2)
266291
dev = current_device()
267292
queue = global_queue(dev)
268-
cmdbuf = MTLCommandBuffer(queue)
269-
enqueue!(cmdbuf)
270293

271-
X = MtlMatrix{T}(undef, size(B))
294+
Ad = MtlMatrix(A)
295+
Br = reshape(B, (M,N))
296+
X = similar(Br)
272297

273-
mps_a = MPSMatrix(A)
274-
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
298+
mps_a = MPSMatrix(Ad)
299+
mps_b = MPSMatrix(Br)
275300
mps_x = MPSMatrix(X)
276301

277-
solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0)
278-
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
279-
commit!(cmdbuf)
302+
303+
buf = MTLCommandBuffer(queue) do cmdbuf
304+
kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0)
305+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
306+
end
307+
308+
wait_completed(buf)
280309

281-
return X
310+
copy!(Br, X)
311+
return B
282312
end
283313

284-
# function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
285-
# require_one_based_indexing(A, B)
286-
# m, n = size(A)
287-
# if m == n
288-
# if istril(A)
289-
# if istriu(A)
290-
# return Diagonal(A) \ B
291-
# else
292-
# return LowerTriangular(A) \ B
293-
# end
294-
# end
295-
# if istriu(A)
296-
# return UpperTriangular(A) \ B
297-
# end
298-
# return lu(A) \ B
299-
# end
300-
# return qr(A, ColumnNorm()) \ B
301-
# end
314+
315+
function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
316+
M,N = size(B,1), size(B,2)
317+
dev = current_device()
318+
queue = global_queue(dev)
319+
320+
Ad = MtlMatrix(A)
321+
Br = reshape(B, (M,N))
322+
X = similar(Br)
323+
324+
mps_a = MPSMatrix(Ad)
325+
mps_b = MPSMatrix(Br)
326+
mps_x = MPSMatrix(X)
327+
328+
329+
buf = MTLCommandBuffer(queue) do cmdbuf
330+
kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0)
331+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
332+
end
333+
334+
wait_completed(buf)
335+
336+
copy!(Br, X)
337+
return B
338+
end

test/mps.jl

+35
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,39 @@ end
5858
@test_throws SingularException lu(A)
5959
end
6060

61+
@testset "solves" begin
62+
b = MtlVector(rand(Float32, 1024))
63+
B = MtlMatrix(rand(Float32, 1024, 1024))
64+
65+
A = MtlMatrix(rand(Float32, 1024, 512))
66+
x = lu(A) \ b
67+
@test A * x ≈ b
68+
X = lu(A) \ B
69+
@test A * X ≈ B
70+
71+
A = UpperTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
72+
x = A \ b
73+
@test A * x ≈ b
74+
X = A \ B
75+
@test A * X ≈ B
76+
77+
A = UnitUpperTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
78+
x = A \ b
79+
@test A * x ≈ b
80+
X = A \ B
81+
@test A * X ≈ B
82+
83+
A = LowerTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
84+
x = A \ b
85+
@test A * x ≈ b
86+
X = A \ B
87+
@test A * X ≈ B
88+
89+
A = UnitLowerTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
90+
x = A \ b
91+
@test A * x ≈ b
92+
X = A \ B
93+
@test A * X ≈ B
94+
end
95+
6196
end

0 commit comments

Comments
 (0)