Skip to content

Commit 5ab3330

Browse files
committed
baseline
1 parent f5d932a commit 5ab3330

File tree

2 files changed

+95
-34
lines changed

2 files changed

+95
-34
lines changed

Diff for: lib/mps/linalg.jl

+60-34
Original file line numberDiff line numberDiff line change
@@ -192,30 +192,43 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
192192

193193
commit!(cmdbuf)
194194

195+
wait_completed(cmdbuf)
196+
195197
return B
196198
end
197199

198200

201+
function LinearAlgebra.:(\)(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
202+
C = deepcopy(B)
203+
LinearAlgebra.ldiv!(A, C)
204+
return C
205+
end
206+
207+
199208
function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
200209
M,N = size(B,1), size(B,2)
201210
dev = current_device()
202211
queue = global_queue(dev)
203212

204-
Bt = reshape(B, (N,M))
213+
At = similar(A.factors)
214+
Bt = similar(B, (N,M))
205215
P = reshape((A.ipiv .- UInt32(1)), (1,M))
206-
X = similar(B)
216+
X = similar(B, (N,M))
217+
218+
transpose!(At, A.factors)
219+
transpose!(Bt, B)
207220

208-
mps_a = MPSMatrix(A.factors)
221+
mps_a = MPSMatrix(At)
209222
mps_b = MPSMatrix(Bt)
210223
mps_p = MPSMatrix(P)
211224
mps_x = MPSMatrix(X)
212225

213226
MTLCommandBuffer(queue) do cmdbuf
214-
kernel = MPSMatrixSolveLU(dev, true, M, N)
227+
kernel = MPSMatrixSolveLU(dev, false, M, N)
215228
encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
216229
end
217230

218-
Bt .= X
231+
transpose!(B, X)
219232
return B
220233
end
221234

@@ -225,20 +238,24 @@ function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrM
225238
dev = current_device()
226239
queue = global_queue(dev)
227240

228-
Ad = MtlMatrix(A; storage=Private)
229-
Bt = reshape(B, (N,M))
230-
X = similar(B)
241+
Ad = MtlMatrix(A')
242+
Br = similar(B, (M,M))
243+
X = similar(Br)
244+
245+
transpose!(Br, B)
231246

232247
mps_a = MPSMatrix(Ad)
233-
mps_b = MPSMatrix(Bt)
248+
mps_b = MPSMatrix(Br)
234249
mps_x = MPSMatrix(X)
235250

236-
MTLCommandBuffer(queue) do cmdbuf
237-
kernel = MPSMatrixSolveTriangular(dev, false, false, false, false, M, N, 1.0)
251+
buf = MTLCommandBuffer(queue) do cmdbuf
252+
kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, N, M, 1.0)
238253
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
239254
end
240255

241-
Bt .= X
256+
wait_completed(buf)
257+
258+
copy!(B, X)
242259
return B
243260
end
244261

@@ -248,20 +265,23 @@ function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVe
248265
dev = current_device()
249266
queue = global_queue(dev)
250267

251-
Ad = MtlMatrix(A; storage=Private)
252-
Bt = reshape(B, (N,M))
253-
X = similar(B)
268+
Ad = MtlMatrix(A)
269+
Br = reshape(B, (M,N))
270+
X = similar(Br)
254271

255272
mps_a = MPSMatrix(Ad)
256-
mps_b = MPSMatrix(Bt)
273+
mps_b = MPSMatrix(Br)
257274
mps_x = MPSMatrix(X)
258275

259-
MTLCommandBuffer(queue) do cmdbuf
260-
kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0)
276+
277+
buf = MTLCommandBuffer(queue) do cmdbuf
278+
kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0)
261279
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
262280
end
263281

264-
Bt .= X
282+
wait_completed(buf)
283+
284+
copy!(Br, X)
265285
return B
266286
end
267287

@@ -271,20 +291,23 @@ function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrM
271291
dev = current_device()
272292
queue = global_queue(dev)
273293

274-
Ad = MtlMatrix(A; storage=Private)
275-
Bt = reshape(B, (N,M))
276-
X = similar(B)
294+
Ad = MtlMatrix(A)
295+
Br = reshape(B, (M,N))
296+
X = similar(Br)
277297

278298
mps_a = MPSMatrix(Ad)
279-
mps_b = MPSMatrix(Bt)
299+
mps_b = MPSMatrix(Br)
280300
mps_x = MPSMatrix(X)
281301

282-
MTLCommandBuffer(queue) do cmdbuf
283-
kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)
302+
303+
buf = MTLCommandBuffer(queue) do cmdbuf
304+
kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0)
284305
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
285306
end
286307

287-
Bt .= X
308+
wait_completed(buf)
309+
310+
copy!(Br, X)
288311
return B
289312
end
290313

@@ -294,19 +317,22 @@ function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVe
294317
dev = current_device()
295318
queue = global_queue(dev)
296319

297-
A = MtlMatrix(A; storage=Private)
298-
Bt = reshape(B, (N,M))
299-
X = similar(B)
320+
Ad = MtlMatrix(A)
321+
Br = reshape(B, (M,N))
322+
X = similar(Br)
300323

301-
mps_a = MPSMatrix(A)
302-
mps_b = MPSMatrix(Bt)
324+
mps_a = MPSMatrix(Ad)
325+
mps_b = MPSMatrix(Br)
303326
mps_x = MPSMatrix(X)
304327

305-
MTLCommandBuffer(queue) do cmdbuf
306-
kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0)
328+
329+
buf = MTLCommandBuffer(queue) do cmdbuf
330+
kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0)
307331
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
308332
end
309333

310-
Bt .= X
334+
wait_completed(buf)
335+
336+
copy!(Br, X)
311337
return B
312338
end

Diff for: test/mps.jl

+35
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,39 @@ end
5858
@test_throws SingularException lu(A)
5959
end
6060

61+
@testset "solves" begin
62+
b = MtlVector(rand(Float32, 1024))
63+
B = MtlMatrix(rand(Float32, 1024, 1024))
64+
65+
A = MtlMatrix(rand(Float32, 1024, 512))
66+
x = lu(A) \ b
67+
@test A * x ≈ b
68+
X = lu(A) \ B
69+
@test A * X ≈ B
70+
71+
A = UpperTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
72+
x = A \ b
73+
@test A * x ≈ b
74+
X = A \ B
75+
@test A * X ≈ B
76+
77+
A = UnitUpperTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
78+
x = A \ b
79+
@test A * x ≈ b
80+
X = A \ B
81+
@test A * X ≈ B
82+
83+
A = LowerTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
84+
x = A \ b
85+
@test A * x ≈ b
86+
X = A \ B
87+
@test A * X ≈ B
88+
89+
A = UnitLowerTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
90+
x = A \ b
91+
@test A * x ≈ b
92+
X = A \ B
93+
@test A * X ≈ B
94+
end
95+
6196
end

0 commit comments

Comments
 (0)