@@ -192,3 +192,108 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
192
192
193
193
return B
194
194
end
195
+
196
+
197
+ function LinearAlgebra. ldiv! (A:: LU{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
198
+ # TODO
199
+ end
200
+
201
+ function LinearAlgebra. ldiv! (A:: UpperTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
202
+ M,N = size (B)
203
+ dev = current_device ()
204
+ queue = global_queue (dev)
205
+ cmdbuf = MTLCommandBuffer (queue)
206
+ enqueue! (cmdbuf)
207
+
208
+ X = MtlMatrix {T} (undef, size (B))
209
+
210
+ mps_a = MPSMatrix (A)
211
+ mps_b = MPSMatrix (B) # TODO reshape to matrix if B is a vector
212
+ mps_x = MPSMatrix (X)
213
+
214
+ solve_kernel = MPSMatrixSolveTriangular (dev, false , false , false , false , M, N, 1.0 ) # TODO : likely N, M is the correct order
215
+ encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
216
+ commit! (cmdbuf)
217
+
218
+ return X
219
+ end
220
+
221
+ function LinearAlgebra. ldiv! (A:: UnitUpperTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
222
+ M,N = size (B)
223
+ dev = current_device ()
224
+ queue = global_queue (dev)
225
+ cmdbuf = MTLCommandBuffer (queue)
226
+ enqueue! (cmdbuf)
227
+
228
+ Bh = reshape (B, )
229
+ X = MtlMatrix {T} (undef, size (B))
230
+
231
+ mps_a = MPSMatrix (A)
232
+ mps_b = MPSMatrix (Bh) # TODO reshape to matrix if B is a vector
233
+ mps_x = MPSMatrix (X)
234
+
235
+ solve_kernel = MPSMatrixSolveTriangular (dev, false , false , false , true , M, N, 1.0 )
236
+ encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
237
+ commit! (cmdbuf)
238
+
239
+ return X
240
+ end
241
+
242
+ function LinearAlgebra. ldiv! (A:: LowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
243
+ M,N = size (B)
244
+ dev = current_device ()
245
+ queue = global_queue (dev)
246
+ cmdbuf = MTLCommandBuffer (queue)
247
+ enqueue! (cmdbuf)
248
+
249
+ X = MtlMatrix {T} (undef, size (B))
250
+
251
+ mps_a = MPSMatrix (A)
252
+ mps_b = MPSMatrix (B) # TODO reshape to matrix if B is a vector
253
+ mps_x = MPSMatrix (X)
254
+
255
+ solve_kernel = MPSMatrixSolveTriangular (dev, false , true , false , false , M, N, 1.0 )
256
+ encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
257
+ commit! (cmdbuf)
258
+
259
+ return X
260
+ end
261
+
262
+ function LinearAlgebra. ldiv! (A:: UnitLowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
263
+ M,N = size (B)
264
+ dev = current_device ()
265
+ queue = global_queue (dev)
266
+ cmdbuf = MTLCommandBuffer (queue)
267
+ enqueue! (cmdbuf)
268
+
269
+ X = MtlMatrix {T} (undef, size (B))
270
+
271
+ mps_a = MPSMatrix (A)
272
+ mps_b = MPSMatrix (B) # TODO reshape to matrix if B is a vector
273
+ mps_x = MPSMatrix (X)
274
+
275
+ solve_kernel = MPSMatrixSolveTriangular (dev, false , true , false , true , M, N, 1.0 )
276
+ encode! (cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
277
+ commit! (cmdbuf)
278
+
279
+ return X
280
+ end
281
+
282
+ # function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
283
+ # require_one_based_indexing(A, B)
284
+ # m, n = size(A)
285
+ # if m == n
286
+ # if istril(A)
287
+ # if istriu(A)
288
+ # return Diagonal(A) \ B
289
+ # else
290
+ # return LowerTriangular(A) \ B
291
+ # end
292
+ # end
293
+ # if istriu(A)
294
+ # return UpperTriangular(A) \ B
295
+ # end
296
+ # return lu(A) \ B
297
+ # end
298
+ # return qr(A, ColumnNorm()) \ B
299
+ # end
0 commit comments