@@ -192,110 +192,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
192
192
193
193
commit! (cmdbuf)
194
194
195
+ wait_completed (cmdbuf)
196
+
195
197
return B
196
198
end
197
199
198
200
201
+ function LinearAlgebra.:(\ )(A:: LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}} , B:: MtlVecOrMat{T} ) where {T<: MtlFloat }
202
+ C = deepcopy (B)
203
+ LinearAlgebra. ldiv! (A, C)
204
+ return C
205
+ end
206
+
207
+
199
208
function LinearAlgebra. ldiv! (A:: LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}} , B:: MtlVecOrMat{T} ) where {T<: MtlFloat }
200
- orig = size (B)
201
- M,N = size (B)[1 ], ndims (B) > 1 ? size (B)[2 ] : 1
209
+ M,N = size (B,1 ), size (B,2 )
202
210
dev = current_device ()
203
211
queue = global_queue (dev)
204
212
205
- B = reshape (B, (N,M))
213
+ At = similar (A. factors)
214
+ Bt = similar (B, (N,M))
206
215
P = reshape ((A. ipiv .- UInt32 (1 )), (1 ,M))
207
- X = similar (B)
216
+ X = similar (B, (N,M) )
208
217
209
- mps_a = MPSMatrix (A. factors)
210
- mps_b = MPSMatrix (B)
218
+ transpose! (At, A. factors)
219
+ transpose! (Bt, B)
220
+
221
+ mps_a = MPSMatrix (At)
222
+ mps_b = MPSMatrix (Bt)
211
223
mps_p = MPSMatrix (P)
212
224
mps_x = MPSMatrix (X)
213
225
214
226
MTLCommandBuffer (queue) do cmdbuf
215
- kernel = MPSMatrixSolveLU (dev, true , M, N)
227
+ kernel = MPSMatrixSolveLU (dev, false , M, N)
216
228
encode! (cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
217
229
end
218
230
219
- B . = X
220
- B = reshape (B, orig)
231
+ transpose! (B, X)
232
+ return B
221
233
end
222
234
223
- function LinearAlgebra. ldiv! (A:: UnitUpperTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
224
- M,N = size (B)
235
+
236
+ function LinearAlgebra. ldiv! (A:: UpperTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T<: MtlFloat }
237
+ M,N = size (B,1 ), size (B,2 )
225
238
dev = current_device ()
226
239
queue = global_queue (dev)
227
- cmdbuf = MTLCommandBuffer (queue)
228
- enqueue! (cmdbuf)
229
240
230
- Bh = reshape (B, )
231
- X = MtlMatrix {T} (undef, size (B))
241
+ Ad = MtlMatrix (A' )
242
+ Br = similar (B, (M,M))
243
+ X = similar (Br)
232
244
233
- mps_a = MPSMatrix (A)
234
- mps_b = MPSMatrix (Bh) # TODO reshape to matrix if B is a vector
245
+ transpose! (Br, B)
246
+
247
+ mps_a = MPSMatrix (Ad)
248
+ mps_b = MPSMatrix (Br)
235
249
mps_x = MPSMatrix (X)
236
250
237
- solve_kernel = MPSMatrixSolveTriangular (dev, false , false , false , true , M, N, 1.0 )
238
- encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
239
- commit! (cmdbuf)
251
+ buf = MTLCommandBuffer (queue) do cmdbuf
252
+ kernel = MPSMatrixSolveTriangular (dev, false , true , false , false , N, M, 1.0 )
253
+ encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
254
+ end
240
255
241
- return X
256
+ wait_completed (buf)
257
+
258
+ copy! (B, X)
259
+ return B
242
260
end
243
261
244
- function LinearAlgebra. ldiv! (A:: LowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
245
- M,N = size (B)
262
+
263
+ function LinearAlgebra. ldiv! (A:: UnitUpperTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T<: MtlFloat }
264
+ M,N = size (B,1 ), size (B,2 )
246
265
dev = current_device ()
247
266
queue = global_queue (dev)
248
- cmdbuf = MTLCommandBuffer (queue)
249
- enqueue! (cmdbuf)
250
267
251
- X = MtlMatrix {T} (undef, size (B))
268
+ Ad = MtlMatrix (A)
269
+ Br = reshape (B, (M,N))
270
+ X = similar (Br)
252
271
253
- mps_a = MPSMatrix (A )
254
- mps_b = MPSMatrix (B) # TODO reshape to matrix if B is a vector
272
+ mps_a = MPSMatrix (Ad )
273
+ mps_b = MPSMatrix (Br)
255
274
mps_x = MPSMatrix (X)
256
275
257
- solve_kernel = MPSMatrixSolveTriangular (dev, false , true , false , false , M, N, 1.0 )
258
- encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
259
- commit! (cmdbuf)
276
+
277
+ buf = MTLCommandBuffer (queue) do cmdbuf
278
+ kernel = MPSMatrixSolveTriangular (dev, true , false , false , true , M, N, 1.0 )
279
+ encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
280
+ end
260
281
261
- return X
282
+ wait_completed (buf)
283
+
284
+ copy! (Br, X)
285
+ return B
262
286
end
263
287
264
- function LinearAlgebra. ldiv! (A:: UnitLowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
265
- M,N = size (B)
288
+
289
+ function LinearAlgebra. ldiv! (A:: LowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T<: MtlFloat }
290
+ M,N = size (B,1 ), size (B,2 )
266
291
dev = current_device ()
267
292
queue = global_queue (dev)
268
- cmdbuf = MTLCommandBuffer (queue)
269
- enqueue! (cmdbuf)
270
293
271
- X = MtlMatrix {T} (undef, size (B))
294
+ Ad = MtlMatrix (A)
295
+ Br = reshape (B, (M,N))
296
+ X = similar (Br)
272
297
273
- mps_a = MPSMatrix (A )
274
- mps_b = MPSMatrix (B) # TODO reshape to matrix if B is a vector
298
+ mps_a = MPSMatrix (Ad )
299
+ mps_b = MPSMatrix (Br)
275
300
mps_x = MPSMatrix (X)
276
301
277
- solve_kernel = MPSMatrixSolveTriangular (dev, false , true , false , true , M, N, 1.0 )
278
- encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
279
- commit! (cmdbuf)
302
+
303
+ buf = MTLCommandBuffer (queue) do cmdbuf
304
+ kernel = MPSMatrixSolveTriangular (dev, true , true , false , false , M, N, 1.0 )
305
+ encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
306
+ end
307
+
308
+ wait_completed (buf)
280
309
281
- return X
310
+ copy! (Br, X)
311
+ return B
282
312
end
283
313
284
- # function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
285
- # require_one_based_indexing(A, B)
286
- # m, n = size(A)
287
- # if m == n
288
- # if istril(A)
289
- # if istriu(A)
290
- # return Diagonal(A) \ B
291
- # else
292
- # return LowerTriangular(A) \ B
293
- # end
294
- # end
295
- # if istriu(A)
296
- # return UpperTriangular(A) \ B
297
- # end
298
- # return lu(A) \ B
299
- # end
300
- # return qr(A, ColumnNorm()) \ B
301
- # end
314
+
315
+ function LinearAlgebra. ldiv! (A:: UnitLowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T<: MtlFloat }
316
+ M,N = size (B,1 ), size (B,2 )
317
+ dev = current_device ()
318
+ queue = global_queue (dev)
319
+
320
+ Ad = MtlMatrix (A)
321
+ Br = reshape (B, (M,N))
322
+ X = similar (Br)
323
+
324
+ mps_a = MPSMatrix (Ad)
325
+ mps_b = MPSMatrix (Br)
326
+ mps_x = MPSMatrix (X)
327
+
328
+
329
+ buf = MTLCommandBuffer (queue) do cmdbuf
330
+ kernel = MPSMatrixSolveTriangular (dev, true , true , false , true , M, N, 1.0 )
331
+ encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
332
+ end
333
+
334
+ wait_completed (buf)
335
+
336
+ copy! (Br, X)
337
+ return B
338
+ end
0 commit comments