Skip to content

Commit efbf84e

Browse files
Add new blas/lapack interfaces (#6658)
* Add blas_copy
* Add blas_nrm2
* Fix blas_nrm2 unittest
* Fix scnrm2_ return type
* Sort and classify lapack routines
* Add geqrf
* Add geqrf lapack C interface
* geqrf_inplace with tests
* Comment test auxiliary code to be used later
* Add the description of cusolver_utils.h, temporarily disabled
* Update heevd interface to add lda
* Update lapack_test to new interface
---------
Co-authored-by: Mohan Chen <[email protected]>
1 parent 46bc1a4 commit efbf84e

File tree

14 files changed

+1295
-252
lines changed

14 files changed

+1295
-252
lines changed

source/source_base/module_container/ATen/kernels/blas.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,32 @@
33
namespace container {
44
namespace kernels {
55

6+
7+
// CPU specialization of the blas_copy functor.
// Forwards its arguments unchanged to BlasConnector::copy.
//   n    - element count passed to the BLAS copy routine
//   x    - source vector (read-only)
//   incx - stride argument for x
//   y    - destination vector
//   incy - stride argument for y
// NOTE(review): n/incx/incy presumably follow the reference BLAS xCOPY
// convention; BlasConnector::copy is not visible here — confirm.
template <typename T>
8+
struct blas_copy<T, DEVICE_CPU> {
9+
void operator()(
10+
const int n,
11+
const T *x,
12+
const int incx,
13+
T *y,
14+
const int incy)
15+
{
16+
BlasConnector::copy(n, x, incx, y, incy);
17+
}
18+
};
19+
20+
// CPU specialization of the blas_nrm2 functor.
// Returns whatever BlasConnector::nrm2(n, x, incx) yields, typed as Real.
// Real is GetTypeReal<T>::type (presumably the real scalar type underlying T,
// e.g. float for std::complex<float> — confirm against GetTypeReal).
// NOTE(review): by its name this is presumably the BLAS xNRM2 Euclidean norm;
// BlasConnector::nrm2 is not visible here — confirm.
template <typename T>
21+
struct blas_nrm2<T, DEVICE_CPU> {
22+
using Real = typename GetTypeReal<T>::type;
23+
Real operator()(
24+
const int n,
25+
const T *x,
26+
const int incx)
27+
{
28+
return BlasConnector::nrm2(n, x, incx);
29+
}
30+
};
31+
632
template <typename T>
733
struct blas_dot<T, DEVICE_CPU> {
834
void operator()(
@@ -175,6 +201,17 @@ struct blas_gemm_batched_strided<T, DEVICE_CPU> {
175201
};
176202

177203
// Explicitly instantiate functors for the types of functor registered.
204+
205+
// Explicit instantiations of the CPU blas_copy and blas_nrm2 functors for
// float, double, std::complex<float> and std::complex<double>.
template struct blas_copy<float , DEVICE_CPU>;
206+
template struct blas_copy<double, DEVICE_CPU>;
207+
template struct blas_copy<std::complex<float >, DEVICE_CPU>;
208+
template struct blas_copy<std::complex<double>, DEVICE_CPU>;
209+
210+
template struct blas_nrm2<float , DEVICE_CPU>;
211+
template struct blas_nrm2<double, DEVICE_CPU>;
212+
template struct blas_nrm2<std::complex<float >, DEVICE_CPU>;
213+
template struct blas_nrm2<std::complex<double>, DEVICE_CPU>;
214+
178215
template struct blas_dot<float , DEVICE_CPU>;
179216
template struct blas_dot<double, DEVICE_CPU>;
180217
template struct blas_dot<std::complex<float >, DEVICE_CPU>;
@@ -221,4 +258,4 @@ template struct blas_gemm_batched_strided<std::complex<float >, DEVICE_CPU>;
221258
template struct blas_gemm_batched_strided<std::complex<double>, DEVICE_CPU>;
222259

223260
} // namespace kernels
224-
} // namespace container
261+
} // namespace container

source/source_base/module_container/ATen/kernels/blas.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,26 @@
99
namespace container {
1010
namespace kernels {
1111

12+
// Primary template of the blas_copy functor, parameterized on the element
// type T and a Device tag. Only the call operator is declared here; its
// definition is expected to come from per-device specializations.
//   n    - element count
//   x    - source vector (read-only)
//   incx - stride argument for x
//   y    - destination vector
//   incy - stride argument for y
// NOTE(review): the comment below names DCOPY (the double-precision BLAS
// routine), but the functor is templated on T — presumably it stands for the
// whole xCOPY family; confirm.
template <typename T, typename Device>
13+
struct blas_copy {
14+
// DCOPY copies a vector, x, to a vector, y.
15+
void operator()(
16+
const int n,
17+
const T *x,
18+
const int incx,
19+
T *y,
20+
const int incy);
21+
};
22+
23+
// Primary template of the blas_nrm2 functor, parameterized on the element
// type T and a Device tag. Only the call operator is declared here.
// The result type Real is GetTypeReal<T>::type (presumably the real scalar
// type underlying T — confirm against GetTypeReal).
//   n    - element count
//   x    - input vector (read-only)
//   incx - stride argument for x
// NOTE(review): by its name this is presumably the BLAS xNRM2 Euclidean norm
// of x — confirm against the device specializations.
template <typename T, typename Device>
24+
struct blas_nrm2 {
25+
using Real = typename GetTypeReal<T>::type;
26+
Real operator()(
27+
const int n,
28+
const T *x,
29+
const int incx);
30+
};
31+
1232
template <typename T, typename Device>
1333
struct blas_dot {
1434
void operator()(
@@ -168,4 +188,4 @@ void destroyGpuBlasHandle(); // destroy blas handle
168188
} // namespace kernels
169189
} // namespace container
170190

171-
#endif // ATEN_KERNELS_BLAS_H_
191+
#endif // ATEN_KERNELS_BLAS_H_

source/source_base/module_container/ATen/kernels/cuda/blas.cu

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,32 @@ void destroyGpuBlasHandle() {
2222
}
2323
}
2424

25+
// GPU specialization of the blas_nrm2 functor.
// Calls cuBlasConnector::nrm2 with cublas_handle (defined outside this block),
// passing the address of a local Real, and returns that local by value.
// Real is GetTypeReal<T>::type (presumably the real scalar type underlying T
// — confirm against GetTypeReal).
// NOTE(review): x is presumably a device pointer, as with cuBLAS nrm2 —
// confirm against cuBlasConnector.
template <typename T>
26+
struct blas_nrm2<T, DEVICE_GPU> {
27+
using Real = typename GetTypeReal<T>::type;
28+
Real operator()(
29+
const int n,
30+
const T *x,
31+
const int incx)
32+
{
33+
// NOTE(review): result has no initializer; if the connector call does not
// write to it (e.g. on a cuBLAS error), the returned value is indeterminate.
// Whether cuBlasConnector::nrm2 reports errors is not visible here — confirm.
Real result;
34+
cuBlasConnector::nrm2(cublas_handle, n, x, incx, &result);
35+
return result;
36+
}
37+
};
38+
39+
// GPU specialization of the blas_copy functor.
// Forwards its arguments, prefixed with cublas_handle (defined outside this
// block), to cuBlasConnector::copy.
//   n    - element count
//   x    - source vector (read-only)
//   incx - stride argument for x
//   y    - destination vector
//   incy - stride argument for y
// NOTE(review): x and y are presumably device pointers, as with cuBLAS copy —
// confirm against cuBlasConnector.
template <typename T>
40+
struct blas_copy<T, DEVICE_GPU> {
41+
void operator()(
42+
const int n,
43+
const T * x,
44+
const int incx,
45+
T *y,
46+
const int incy)
47+
{
48+
cuBlasConnector::copy(cublas_handle, n, x, incx, y, incy);
49+
}
50+
};
2551

2652
template <typename T>
2753
struct blas_dot<T, DEVICE_GPU> {
@@ -76,7 +102,7 @@ struct blas_gemv<T, DEVICE_GPU> {
76102
const int& incx,
77103
const T* beta,
78104
T* y,
79-
const int& incy)
105+
const int& incy)
80106
{
81107
cuBlasConnector::gemv(cublas_handle, trans, m, n, *alpha, A, lda, x, incx, *beta, y, incy);
82108
}
@@ -196,6 +222,19 @@ struct blas_gemm_batched_strided<T, DEVICE_GPU> {
196222
};
197223

198224
// Explicitly instantiate functors for the types of functor registered.
225+
226+
227+
228+
// Explicit instantiations of the GPU blas_copy and blas_nrm2 functors for
// float, double, std::complex<float> and std::complex<double>.
template struct blas_copy<float , DEVICE_GPU>;
229+
template struct blas_copy<double, DEVICE_GPU>;
230+
template struct blas_copy<std::complex<float> , DEVICE_GPU>;
231+
template struct blas_copy<std::complex<double>, DEVICE_GPU>;
232+
233+
template struct blas_nrm2<float , DEVICE_GPU>;
234+
template struct blas_nrm2<double, DEVICE_GPU>;
235+
template struct blas_nrm2<std::complex<float> , DEVICE_GPU>;
236+
template struct blas_nrm2<std::complex<double>, DEVICE_GPU>;
237+
199238
template struct blas_dot<float , DEVICE_GPU>;
200239
template struct blas_dot<double, DEVICE_GPU>;
201240
template struct blas_dot<std::complex<float> , DEVICE_GPU>;
@@ -242,4 +281,4 @@ template struct blas_gemm_batched_strided<std::complex<float >, DEVICE_GPU>;
242281
template struct blas_gemm_batched_strided<std::complex<double>, DEVICE_GPU>;
243282

244283
} // namespace kernels
245-
} // namespace container
284+
} // namespace container

0 commit comments

Comments
 (0)