@@ -192,30 +192,43 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
192
192
193
193
commit! (cmdbuf)
194
194
195
+ wait_completed (cmdbuf)
196
+
195
197
return B
196
198
end
197
199
198
200
201
# Non-mutating solve for an LU factorization stored on the Metal device.
#
# Duplicates `B`, hands the duplicate to the in-place `ldiv!` method, and returns
# the duplicate, so the caller's `B` is left untouched.
function LinearAlgebra.:(\)(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # A flat array of `MtlFloat` elements has no nested mutable state, so a plain
    # `copy` gives an independent buffer; `deepcopy` is not needed here.
    C = copy(B)
    LinearAlgebra.ldiv!(A, C)
    return C
end
206
+
207
+
199
208
# In-place solve for an LU factorization whose factors and pivots live on the
# Metal device. The result overwrites `B`, and `B` is returned.
#
# The factors and the right-hand side are first transposed into scratch buffers,
# the MPS LU solver is run on those, and the solver output is transposed back into
# `B`. NOTE(review): presumably this bridges Julia's column-major layout and the
# row-major layout MPS expects — confirm against the MPSMatrix wrapper.
function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
    nrows = size(B, 1)
    nrhs = size(B, 2)

    dev = current_device()
    queue = global_queue(dev)

    # Scratch buffers: transposed factors, transposed right-hand side, and the
    # (nrhs x nrows) buffer the solver writes its result into.
    factors_t = similar(A.factors)
    rhs_t = similar(B, (nrhs, nrows))
    sol_t = similar(B, (nrhs, nrows))

    # Pivot indices shifted down by one and laid out as a single row.
    # NOTE(review): presumably MPS wants zero-based pivots — confirm.
    pivots = reshape(A.ipiv .- one(UInt32), (1, nrows))

    transpose!(factors_t, A.factors)
    transpose!(rhs_t, B)

    mps_factors = MPSMatrix(factors_t)
    mps_rhs = MPSMatrix(rhs_t)
    mps_pivots = MPSMatrix(pivots)
    mps_sol = MPSMatrix(sol_t)

    MTLCommandBuffer(queue) do cmdbuf
        # The `false` argument is kept exactly as in the original call; the inputs
        # were already transposed explicitly above.
        solver = MPSMatrixSolveLU(dev, false, nrows, nrhs)
        encode!(cmdbuf, solver, mps_factors, mps_rhs, mps_pivots, mps_sol)
    end

    # Bring the solver output back into the caller's array.
    transpose!(B, sol_t)
    return B
end
221
234
@@ -225,20 +238,24 @@ function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrM
225
238
dev = current_device ()
226
239
queue = global_queue (dev)
227
240
228
- Ad = MtlMatrix (A; storage= Private)
229
- Bt = reshape (B, (N,M))
230
- X = similar (B)
241
+ Ad = MtlMatrix (A' )
242
+ Br = similar (B, (M,M))
243
+ X = similar (Br)
244
+
245
+ transpose! (Br, B)
231
246
232
247
mps_a = MPSMatrix (Ad)
233
- mps_b = MPSMatrix (Bt )
248
+ mps_b = MPSMatrix (Br )
234
249
mps_x = MPSMatrix (X)
235
250
236
- MTLCommandBuffer (queue) do cmdbuf
237
- kernel = MPSMatrixSolveTriangular (dev, false , false , false , false , M, N , 1.0 )
251
+ buf = MTLCommandBuffer (queue) do cmdbuf
252
+ kernel = MPSMatrixSolveTriangular (dev, false , true , false , false , N, M , 1.0 )
238
253
encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
239
254
end
240
255
241
- Bt .= X
256
+ wait_completed (buf)
257
+
258
+ copy! (B, X)
242
259
return B
243
260
end
244
261
@@ -248,20 +265,23 @@ function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVe
248
265
dev = current_device ()
249
266
queue = global_queue (dev)
250
267
251
- Ad = MtlMatrix (A; storage = Private )
252
- Bt = reshape (B, (N,M ))
253
- X = similar (B )
268
+ Ad = MtlMatrix (A)
269
+ Br = reshape (B, (M,N ))
270
+ X = similar (Br )
254
271
255
272
mps_a = MPSMatrix (Ad)
256
- mps_b = MPSMatrix (Bt )
273
+ mps_b = MPSMatrix (Br )
257
274
mps_x = MPSMatrix (X)
258
275
259
- MTLCommandBuffer (queue) do cmdbuf
260
- kernel = MPSMatrixSolveTriangular (dev, false , false , false , true , M, N, 1.0 )
276
+
277
+ buf = MTLCommandBuffer (queue) do cmdbuf
278
+ kernel = MPSMatrixSolveTriangular (dev, true , false , false , true , M, N, 1.0 )
261
279
encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
262
280
end
263
281
264
- Bt .= X
282
+ wait_completed (buf)
283
+
284
+ copy! (Br, X)
265
285
return B
266
286
end
267
287
@@ -271,20 +291,23 @@ function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrM
271
291
dev = current_device ()
272
292
queue = global_queue (dev)
273
293
274
- Ad = MtlMatrix (A; storage = Private )
275
- Bt = reshape (B, (N,M ))
276
- X = similar (B )
294
+ Ad = MtlMatrix (A)
295
+ Br = reshape (B, (M,N ))
296
+ X = similar (Br )
277
297
278
298
mps_a = MPSMatrix (Ad)
279
- mps_b = MPSMatrix (Bt )
299
+ mps_b = MPSMatrix (Br )
280
300
mps_x = MPSMatrix (X)
281
301
282
- MTLCommandBuffer (queue) do cmdbuf
283
- kernel = MPSMatrixSolveTriangular (dev, false , true , false , false , M, N, 1.0 )
302
+
303
+ buf = MTLCommandBuffer (queue) do cmdbuf
304
+ kernel = MPSMatrixSolveTriangular (dev, true , true , false , false , M, N, 1.0 )
284
305
encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
285
306
end
286
307
287
- Bt .= X
308
+ wait_completed (buf)
309
+
310
+ copy! (Br, X)
288
311
return B
289
312
end
290
313
@@ -294,19 +317,22 @@ function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVe
294
317
dev = current_device ()
295
318
queue = global_queue (dev)
296
319
297
- A = MtlMatrix (A; storage = Private )
298
- Bt = reshape (B, (N,M ))
299
- X = similar (B )
320
+ Ad = MtlMatrix (A)
321
+ Br = reshape (B, (M,N ))
322
+ X = similar (Br )
300
323
301
- mps_a = MPSMatrix (A )
302
- mps_b = MPSMatrix (Bt )
324
+ mps_a = MPSMatrix (Ad )
325
+ mps_b = MPSMatrix (Br )
303
326
mps_x = MPSMatrix (X)
304
327
305
- MTLCommandBuffer (queue) do cmdbuf
306
- kernel = MPSMatrixSolveTriangular (dev, false , true , false , true , M, N, 1.0 )
328
+
329
+ buf = MTLCommandBuffer (queue) do cmdbuf
330
+ kernel = MPSMatrixSolveTriangular (dev, true , true , false , true , M, N, 1.0 )
307
331
encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
308
332
end
309
333
310
- Bt .= X
334
+ wait_completed (buf)
335
+
336
+ copy! (Br, X)
311
337
return B
312
338
end
0 commit comments