diff --git a/source/source_base/module_container/ATen/kernels/blas.cpp b/source/source_base/module_container/ATen/kernels/blas.cpp
index eb192a7c9e..5935ed7c28 100644
--- a/source/source_base/module_container/ATen/kernels/blas.cpp
+++ b/source/source_base/module_container/ATen/kernels/blas.cpp
@@ -3,6 +3,32 @@
 namespace container {
 namespace kernels {
+
+template <typename T, typename Device>
+struct blas_copy {
+    void operator()(
+        const int n,
+        const T* x,
+        const int incx,
+        T* y,
+        const int incy)
+    {
+        BlasConnector::copy(n, x, incx, y, incy);
+    }
+};
+
+template <typename T, typename Device>
+struct blas_nrm2 {
+    using Real = typename GetTypeReal<T>::type;
+    Real operator()(
+        const int n,
+        const T* x,
+        const int incx)
+    {
+        return BlasConnector::nrm2(n, x, incx);
+    }
+};
+
 template <typename T, typename Device>
 struct blas_dot {
     void operator()(
@@ -175,6 +201,17 @@ struct blas_gemm_batched_strided {
 };
 
 // Explicitly instantiate functors for the types of functor registered.
+
+template struct blas_copy<float, DEVICE_CPU>;
+template struct blas_copy<double, DEVICE_CPU>;
+template struct blas_copy<std::complex<float>, DEVICE_CPU>;
+template struct blas_copy<std::complex<double>, DEVICE_CPU>;
+
+template struct blas_nrm2<float, DEVICE_CPU>;
+template struct blas_nrm2<double, DEVICE_CPU>;
+template struct blas_nrm2<std::complex<float>, DEVICE_CPU>;
+template struct blas_nrm2<std::complex<double>, DEVICE_CPU>;
+
 template struct blas_dot<float, DEVICE_CPU>;
 template struct blas_dot<double, DEVICE_CPU>;
 template struct blas_dot<std::complex<float>, DEVICE_CPU>;
@@ -221,4 +258,4 @@ template struct blas_gemm_batched_strided<std::complex<float>, DEVICE_CPU>;
 template struct blas_gemm_batched_strided<std::complex<double>, DEVICE_CPU>;
 
 } // namespace kernels
-} // namespace container
\ No newline at end of file
+} // namespace container
diff --git a/source/source_base/module_container/ATen/kernels/blas.h b/source/source_base/module_container/ATen/kernels/blas.h
index 201021199c..550caa2f79 100644
--- a/source/source_base/module_container/ATen/kernels/blas.h
+++ b/source/source_base/module_container/ATen/kernels/blas.h
@@ -9,6 +9,26 @@
 namespace container {
 namespace kernels {
 
+template <typename T, typename Device>
+struct blas_copy {
+    // DCOPY copies a vector, x, to a vector, y.
+ void operator()( + const int n, + const T *x, + const int incx, + T *y, + const int incy); +}; + +template +struct blas_nrm2 { + using Real = typename GetTypeReal::type; + Real operator()( + const int n, + const T *x, + const int incx); +}; + template struct blas_dot { void operator()( @@ -168,4 +188,4 @@ void destroyGpuBlasHandle(); // destory blas handle } // namespace kernels } // namespace container -#endif // ATEN_KERNELS_BLAS_H_ \ No newline at end of file +#endif // ATEN_KERNELS_BLAS_H_ diff --git a/source/source_base/module_container/ATen/kernels/cuda/blas.cu b/source/source_base/module_container/ATen/kernels/cuda/blas.cu index 8d4b5ea227..e8fe5a80bf 100644 --- a/source/source_base/module_container/ATen/kernels/cuda/blas.cu +++ b/source/source_base/module_container/ATen/kernels/cuda/blas.cu @@ -22,6 +22,32 @@ void destroyGpuBlasHandle() { } } +template +struct blas_nrm2 { + using Real = typename GetTypeReal::type; + Real operator()( + const int n, + const T *x, + const int incx) + { + Real result; + cuBlasConnector::nrm2(cublas_handle, n, x, incx, &result); + return result; + } +}; + +template +struct blas_copy { + void operator()( + const int n, + const T * x, + const int incx, + T *y, + const int incy) + { + cuBlasConnector::copy(cublas_handle, n, x, incx, y, incy); + } +}; template struct blas_dot { @@ -76,7 +102,7 @@ struct blas_gemv { const int& incx, const T* beta, T* y, - const int& incy) + const int& incy) { cuBlasConnector::gemv(cublas_handle, trans, m, n, *alpha, A, lda, x, incx, *beta, y, incy); } @@ -196,6 +222,19 @@ struct blas_gemm_batched_strided { }; // Explicitly instantiate functors for the types of functor registered. + + + +template struct blas_copy; +template struct blas_copy; +template struct blas_copy , DEVICE_GPU>; +template struct blas_copy, DEVICE_GPU>; + +template struct blas_nrm2; +template struct blas_nrm2; +template struct blas_nrm2 , DEVICE_GPU>; +template struct blas_nrm2, DEVICE_GPU>; + template struct blas_dot; template struct blas_dot; template struct blas_dot , DEVICE_GPU>; @@ -242,4 +281,4 @@ template struct blas_gemm_batched_strided, DEVICE_GPU>; template struct blas_gemm_batched_strided, DEVICE_GPU>; } // namespace kernels -} // namespace container \ No newline at end of file +} // namespace container diff --git a/source/source_base/module_container/ATen/kernels/cuda/lapack.cu b/source/source_base/module_container/ATen/kernels/cuda/lapack.cu index 96b24f243a..a76c16ef4b 100644 --- a/source/source_base/module_container/ATen/kernels/cuda/lapack.cu +++ b/source/source_base/module_container/ATen/kernels/cuda/lapack.cu @@ -62,6 +62,9 @@ struct set_matrix { } }; + + +// --- 1. Matrix Decomposition --- template struct lapack_trtri { void operator()( @@ -90,17 +93,152 @@ struct lapack_potrf { } }; +template +struct lapack_getrf { + void operator()( + const int& m, + const int& n, + T* Mat, + const int& lda, + int* ipiv) + { + cuSolverConnector::getrf(cusolver_handle, m, n, Mat, lda, ipiv); + } +}; + +template +struct lapack_getri { + void operator()( + const int& n, + T* Mat, + const int& lda, + const int* ipiv, + T* work, + const int& lwork) + { + throw std::runtime_error("cuSOLVER does not provide LU-based matrix inversion interface (getri). 
To compute the inverse on GPU, use getrs instead."); + } +}; + + +template +struct lapack_geqrf_inplace { + void operator()( + const int m, + const int n, + T *d_A, + const int lda) + { + const int k = std::min(m, n); + + // Allocate tau on device + T *d_tau; + cudaErrcheck(cudaMalloc(&d_tau, sizeof(T) * k)); + + cuSolverConnector::geqrf(cusolver_handle, m, n, d_A, lda, d_tau); + + cuSolverConnector::orgqr(cusolver_handle, m, n, k, d_A, lda, d_tau); + + cudaErrcheck(cudaFree(d_tau)); + + // // geqrf: workspace query + + // // In practice, we use helper function to get lwork + // // Or use magma for better interface + // // Let's assume we have a way to get lwork + // // For now, do a dummy call to get it + // size_t workspaceInBytes = 0; + // cusolverErrcheck(cusolverDnXgeqrf_bufferSize( + // cusolverH, m, n, + // getCudaDataType::type, d_A, lda, + // getCudaDataType::type, // for tau + // CUDA_R_32F, // numerical precision + // CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes)); + + // lwork = static_cast(workspaceInBytes / sizeof(T)); + + // // Allocate workspace + // T *d_work; + // cudaErrcheck(cudaMalloc(&d_work, sizeof(T) * lwork)); + + // // 3. Perform geqrf + // cusolverErrcheck(cusolverDnXgeqrf( + // cusolverH, m, n, + // getCudaDataType::type, d_A, lda, + // d_tau, + // getCudaDataType::type, + // d_work, lwork * sizeof(T), + // d_info)); + + // int info; + // cudaErrcheck(cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost)); + // if (info != 0) { + // throw std::runtime_error("cuSOLVER geqrf failed with info = " + std::to_string(info)); + // } + + // // 4. Generate Q using orgqr + // // Query workspace for orgqr + // cusolverErrcheck(cusolverDnXorgqr_bufferSize( + // cusolverH, m, n, k, + // getCudaDataType::type, d_A, lda, + // getCudaDataType::type, d_tau, + // CUDA_R_32F, + // CUSOLVER_WORKSPACE_QUERY_USE_MAX, &workspaceInBytes)); + + // lwork = static_cast(workspaceInBytes / sizeof(T)); + // cudaErrcheck(cudaRealloc(&d_work, sizeof(T) * lwork)); // or realloc + + // // orgqr: generate Q + // cusolverErrcheck(cusolverDnXorgqr( + // cusolverH, m, n, k, + // getCudaDataType::type, d_A, lda, + // getCudaDataType::type, d_tau, + // d_work, lwork * sizeof(T), + // d_info)); + + // cudaErrcheck(cudaMemcpy(&info, d_info, sizeof(int), cudaMemcpyDeviceToHost)); + // if (info != 0) { + // throw std::runtime_error("cuSOLVER orgqr failed with info = " + std::to_string(info)); + // } + + // // Clean up + // cudaErrcheck(cudaFree(d_tau)); + // cudaErrcheck(cudaFree(d_work)); + // cudaErrcheck(cudaFree(d_info)); + } +}; + +// --- 2. Linear System Solvers --- +template +struct lapack_getrs { + void operator()( + const char& trans, + const int& n, + const int& nrhs, + T* A, + const int& lda, + const int* ipiv, + T* B, + const int& ldb) + { + cuSolverConnector::getrs(cusolver_handle, trans, n, nrhs, A, lda, ipiv, B, ldb); + } +}; + + +// --- 3. 
Standard & Generalized Eigenvalue --- template struct lapack_heevd { using Real = typename GetTypeReal::type; void operator()( - const char& jobz, - const char& uplo, + const int dim, T* Mat, - const int& dim, + const int lda, Real* eigen_val) { - cuSolverConnector::heevd(cusolver_handle, jobz, uplo, dim, Mat, dim, eigen_val); + char jobz = 'V'; // Compute eigenvalues and eigenvectors + char uplo = 'U'; + cuSolverConnector::heevd(cusolver_handle, jobz, uplo, dim, Mat, lda, eigen_val); } }; @@ -198,49 +336,6 @@ struct lapack_hegvx { -template -struct lapack_getrf { - void operator()( - const int& m, - const int& n, - T* Mat, - const int& lda, - int* ipiv) - { - cuSolverConnector::getrf(cusolver_handle, m, n, Mat, lda, ipiv); - } -}; - -template -struct lapack_getri { - void operator()( - const int& n, - T* Mat, - const int& lda, - const int* ipiv, - T* work, - const int& lwork) - { - throw std::runtime_error("cuSOLVER does not provide LU-based matrix inversion interface (getri). To compute the inverse on GPU, use getrs instead."); - } -}; - -template -struct lapack_getrs { - void operator()( - const char& trans, - const int& n, - const int& nrhs, - T* A, - const int& lda, - const int* ipiv, - T* B, - const int& ldb) - { - cuSolverConnector::getrs(cusolver_handle, trans, n, nrhs, A, lda, ipiv, B, ldb); - } -}; - template struct set_matrix; template struct set_matrix; template struct set_matrix, DEVICE_GPU>; @@ -256,6 +351,13 @@ template struct lapack_potrf; template struct lapack_potrf, DEVICE_GPU>; template struct lapack_potrf, DEVICE_GPU>; + +template struct lapack_getrs; +template struct lapack_getrs; +template struct lapack_getrs, DEVICE_GPU>; +template struct lapack_getrs, DEVICE_GPU>; + + template struct lapack_heevd; template struct lapack_heevd; template struct lapack_heevd, DEVICE_GPU>; @@ -286,10 +388,10 @@ template struct lapack_getri; template struct lapack_getri, DEVICE_GPU>; template struct lapack_getri, DEVICE_GPU>; -template struct lapack_getrs; -template struct lapack_getrs; -template struct lapack_getrs, DEVICE_GPU>; -template struct lapack_getrs, DEVICE_GPU>; +template struct lapack_geqrf_inplace; +template struct lapack_geqrf_inplace; +template struct lapack_geqrf_inplace, DEVICE_GPU>; +template struct lapack_geqrf_inplace, DEVICE_GPU>; } // namespace kernels } // namespace container diff --git a/source/source_base/module_container/ATen/kernels/lapack.cpp b/source/source_base/module_container/ATen/kernels/lapack.cpp index 0c3e72d76c..2ab02f35c8 100644 --- a/source/source_base/module_container/ATen/kernels/lapack.cpp +++ b/source/source_base/module_container/ATen/kernels/lapack.cpp @@ -40,6 +40,7 @@ struct set_matrix { } }; +// --- 1. 
Matrix Decomposition --- template struct lapack_trtri { void operator()( @@ -73,16 +74,135 @@ struct lapack_potrf { } }; + +template +struct lapack_getrf { + void operator()( + const int& m, + const int& n, + T* Mat, + const int& lda, + int* ipiv) + { + int info = 0; + lapackConnector::getrf(m, n, Mat, lda, ipiv, info); + if (info != 0) { + throw std::runtime_error("getrf failed with info = " + std::to_string(info)); + } + } +}; + +template +struct lapack_getri { + void operator()( + const int& n, + T* Mat, + const int& lda, + const int* ipiv, + T* work, + const int& lwork) + { + int info = 0; + lapackConnector::getri(n, Mat, lda, ipiv, work, lwork, info); + if (info != 0) { + throw std::runtime_error("getri failed with info = " + std::to_string(info)); + } + } +}; + +template +struct lapack_geqrf_inplace { + void operator()( + const int m, + const int n, + T *A, + const int lda) + { + // Tensor or vector? + // 1. tau for storing the Householder reflectors + // tau should be dimension min(m, n) + int k = std::min(m, n); + Tensor tau(DataTypeToEnum::value, DeviceType::CpuDevice, {k}); + tau.zero(); + + int info = 0; + + // 2. query for workspace size + int lwork = -1; + T work_query; + lapackConnector::geqrf(m, n, A, lda, tau.data(), &work_query, lwork, info); + if (info != 0) { + throw std::runtime_error("geqrf workspace query failed with info = " + std::to_string(info)); + } + // allocate workspace + lwork = static_cast(get_real(work_query)); + Tensor work(DataTypeToEnum::value, DeviceType::CpuDevice, {lwork}); + work.zero(); + + // 3. perform QR decomposition + // and A is overwritten with upper R. + // Lower A + tau => Q + lapackConnector::geqrf(m, n, A, lda, tau.data(), work.data(), lwork, info); + if (info != 0) { + throw std::runtime_error("geqrf failed with info = " + std::to_string(info)); + } + + // 4. use orgqr to compute Q + // workspace query + lwork = -1; + lapackConnector::orgqr(m, n, k, A, lda, tau.data(), &work_query, lwork, info); + if (info != 0) { + throw std::runtime_error("orgqr workspace query failed with info = " + std::to_string(info)); + } + // allocate workspace + lwork = static_cast(get_real(work_query)); + work.resize({lwork}); + + // compute Q + lapackConnector::orgqr(m, n, k, A, lda, tau.data(), work.data(), lwork, info); + if (info != 0) { + throw std::runtime_error("orgqr failed with info = " + std::to_string(info)); + } + + // now, A should be overwritten with Q, columns orthogonal + + } +}; + +// --- 2. Linear System Solvers --- +template +struct lapack_getrs { + void operator()( + const char& trans, + const int& n, + const int& nrhs, + T* A, + const int& lda, + const int* ipiv, + T* B, + const int& ldb) + { + int info = 0; + lapackConnector::getrs(trans, n, nrhs, A, lda, ipiv, B, ldb, info); + if (info != 0) { + throw std::runtime_error("getrs failed with info = " + std::to_string(info)); + } + } +}; + + +// --- 3. 
Standard & Generalized Eigenvalue --- template struct lapack_heevd { using Real = typename GetTypeReal::type; void operator()( - const char& jobz, - const char& uplo, + const int dim, T* Mat, - const int& dim, + const int lda, Real* eigen_val) { + char jobz = 'V'; // Compute eigenvalues and eigenvectors + char uplo = 'U'; int info = 0; int lwork = std::max(2 * dim + dim * dim, 1 + 6 * dim + 2 * dim * dim); Tensor work(DataTypeToEnum::value, DeviceType::CpuDevice, {lwork}); @@ -96,7 +216,7 @@ struct lapack_heevd { Tensor iwork(DataTypeToEnum::value, DeviceType::CpuDevice, {liwork}); iwork.zero(); - lapackConnector::heevd(jobz, uplo, dim, Mat, dim, eigen_val, work.data(), lwork, rwork.data(), lrwork, iwork.data(), liwork, info); + lapackConnector::heevd(jobz, uplo, dim, Mat, lda, eigen_val, work.data(), lwork, rwork.data(), lrwork, iwork.data(), liwork, info); if (info != 0) { throw std::runtime_error("heevd failed with info = " + std::to_string(info)); } @@ -114,6 +234,8 @@ struct lapack_heevx { Real *eigen_val, T *eigen_vec) { + // copy Mat to aux, solve heevx(aux, eigen_val, eigen_vec) + // input Mat is not referenced in actual heevx LAPACK routines, and aux is destroyed. Tensor aux(DataTypeToEnum::value, DeviceType::CpuDevice, {n * lda}); // Copy Mat to aux since heevx will destroy it // aux = Mat @@ -338,60 +460,9 @@ struct lapack_hegvx { } }; -template -struct lapack_getrf { - void operator()( - const int& m, - const int& n, - T* Mat, - const int& lda, - int* ipiv) - { - int info = 0; - lapackConnector::getrf(m, n, Mat, lda, ipiv, info); - if (info != 0) { - throw std::runtime_error("getrf failed with info = " + std::to_string(info)); - } - } -}; -template -struct lapack_getri { - void operator()( - const int& n, - T* Mat, - const int& lda, - const int* ipiv, - T* work, - const int& lwork) - { - int info = 0; - lapackConnector::getri(n, Mat, lda, ipiv, work, lwork, info); - if (info != 0) { - throw std::runtime_error("getri failed with info = " + std::to_string(info)); - } - } -}; -template -struct lapack_getrs { - void operator()( - const char& trans, - const int& n, - const int& nrhs, - T* A, - const int& lda, - const int* ipiv, - T* B, - const int& ldb) - { - int info = 0; - lapackConnector::getrs(trans, n, nrhs, A, lda, ipiv, B, ldb, info); - if (info != 0) { - throw std::runtime_error("getrs failed with info = " + std::to_string(info)); - } - } -}; + template struct set_matrix; template struct set_matrix; @@ -408,6 +479,28 @@ template struct lapack_trtri; template struct lapack_trtri, DEVICE_CPU>; template struct lapack_trtri, DEVICE_CPU>; + +template struct lapack_getrf; +template struct lapack_getrf; +template struct lapack_getrf, DEVICE_CPU>; +template struct lapack_getrf, DEVICE_CPU>; + +template struct lapack_getri; +template struct lapack_getri; +template struct lapack_getri, DEVICE_CPU>; +template struct lapack_getri, DEVICE_CPU>; + + +template struct lapack_getrs; +template struct lapack_getrs; +template struct lapack_getrs, DEVICE_CPU>; +template struct lapack_getrs, DEVICE_CPU>; + +template struct lapack_geqrf_inplace; +template struct lapack_geqrf_inplace; +template struct lapack_geqrf_inplace, DEVICE_CPU>; +template struct lapack_geqrf_inplace, DEVICE_CPU>; + template struct lapack_heevd; template struct lapack_heevd; template struct lapack_heevd, DEVICE_CPU>; @@ -428,20 +521,5 @@ template struct lapack_hegvx; template struct lapack_hegvx, DEVICE_CPU>; template struct lapack_hegvx, DEVICE_CPU>; -template struct lapack_getrf; -template struct lapack_getrf; -template 
struct lapack_getrf, DEVICE_CPU>; -template struct lapack_getrf, DEVICE_CPU>; - -template struct lapack_getri; -template struct lapack_getri; -template struct lapack_getri, DEVICE_CPU>; -template struct lapack_getri, DEVICE_CPU>; - -template struct lapack_getrs; -template struct lapack_getrs; -template struct lapack_getrs, DEVICE_CPU>; -template struct lapack_getrs, DEVICE_CPU>; - } // namespace kernels } // namespace container diff --git a/source/source_base/module_container/ATen/kernels/lapack.h b/source/source_base/module_container/ATen/kernels/lapack.h index b3a4a40c4e..117f8ef24b 100644 --- a/source/source_base/module_container/ATen/kernels/lapack.h +++ b/source/source_base/module_container/ATen/kernels/lapack.h @@ -20,6 +20,7 @@ struct set_matrix { }; +// --- 1. Matrix Decomposition --- template struct lapack_trtri { void operator()( @@ -40,6 +41,96 @@ struct lapack_potrf { const int& lda); }; +template +struct lapack_getrf { + void operator()( + const int& m, + const int& n, + T* Mat, + const int& lda, + int* ipiv); +}; + + +template +struct lapack_getri { + void operator()( + const int& n, + T* Mat, + const int& lda, + const int* ipiv, + T* work, + const int& lwork); +}; + +// This is QR factorization in-place +// that will change input Mat A to orthogonal/unitary matrix Q +template +struct lapack_geqrf_inplace { + /** + * @brief Perform in-place QR factorization of a matrix using LAPACK's geqrf function. + * + * This function computes the QR factorization of an m-by-n matrix A as A = Q * R, + * where Q is an orthogonal/unitary matrix and R is an upper triangular matrix. + * The factorization is performed in-place, meaning the input matrix A will be modified. + * + * On exit: A is overwritten with the QR factorization Q orthogonal/unitary matrix + * + * @param m The number of rows in the matrix A. m >= 0 + * @param n The number of columns in the matrix A. n >= 0 + * @param A Pointer to the matrix A to be factorized. On exit, contains the QR factorization + * @param lda The leading dimension of the matrix A. lda >= max(1, m) + */ + void operator()( + const int m, + const int n, + T *A, + const int lda); +}; + +// This is QR factorization +// where [in]Mat will be kept and the results are stored in separate matrix Q +// template +// struct lapack_geqrf{ +// /** +// * Perform QR factorization of a matrix using LAPACK's geqrf function. +// * +// * @param m The number of rows in the matrix. +// * @param n The number of columns in the matrix. +// * @param Mat The matrix to be factorized. +// * On exit, the upper triangle contains the upper triangular matrix R, +// * and the elements below the diagonal, with the array TAU, represent +// * the unitary matrix Q as a product of min(m,n) elementary reflectors. +// * @param lda The leading dimension of the matrix. +// * @param tau Array of size min(m,n) containing the Householder reflectors. +// */ +// void operator()( +// const int m, +// const int n, +// T *Mat, +// const int lda, +// T *tau); +// }; + + +// --- 2. Linear System Solvers --- +template +struct lapack_getrs { + void operator()( + const char& trans, + const int& n, + const int& nrhs, + T* A, + const int& lda, + const int* ipiv, + T* B, + const int& ldb); +}; + + + +// --- 3. 
Standard & Generalized Eigenvalue --- + // ============================================================================ // Standard Hermitian Eigenvalue Problem Solvers // ============================================================================ @@ -54,12 +145,37 @@ struct lapack_potrf { // ============================================================================ template struct lapack_heevd { + // !> ZHEEVD computes all eigenvalues and, optionally, eigenvectors of a + // !> complex Hermitian matrix A. If eigenvectors are desired, it uses a + // !> divide and conquer algorithm. + // !> On exit, if JOBZ = 'V', then if INFO = 0, A contains the + // !> orthonormal eigenvectors of the matrix A. + /** + * @brief Computes all eigenvalues and, optionally, eigenvectors of a complex Hermitian matrix. + * + * This function solves the standard Hermitian eigenvalue problem A*x = lambda*x, + * where A is a Hermitian matrix. It computes all eigenvalues and optionally + * the corresponding eigenvectors using a divide and conquer algorithm. + * + * @param[in] dim The order of the matrix A. dim >= 0. + * @param[in,out] Mat On entry, the Hermitian matrix A. + * On exit, if eigenvectors are computed, A contains the + * orthonormal eigenvectors of the matrix A. + * @param[in] lda The leading dimension of the array Mat. lda >= max(1, dim). + * @param[out] eigen_val Array of size at least dim. On normal exit, contains the + * eigenvalues in ascending order. + * + * @note + * See LAPACK ZHEEVD or CHEEVD documentation for more details. + * The matrix is assumed to be stored in upper or lower triangular form + * according to the uplo parameter (not shown here but typically passed + * to the actual implementation). + */ using Real = typename GetTypeReal::type; void operator()( - const char& jobz, - const char& uplo, + const int dim, T* Mat, - const int& dim, + const int lda, Real* eigen_val); }; @@ -74,7 +190,8 @@ struct lapack_heevx { * * @param dim The order of the matrix A. dim >= 0. * @param lda The leading dimension of the array Mat. lda >= max(1, dim). - * @param Mat On entry, the Hermitian matrix A. On exit, A is kept. + * @param[in] Mat On entry, the Hermitian matrix A. On exit, A is kept. + * Only used to provide values of matrix. * @param neig The number of eigenvalues to be found. 0 <= neig <= dim. * @param eigen_val On normal exit, the first \p neig elements contain the selected * eigenvalues in ascending order. 
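The heevd functor now hard-codes jobz = 'V' and uplo = 'U' internally, so callers pass only the dimension, the matrix, its leading dimension, and the eigenvalue buffer. A minimal CPU usage sketch of the new interface (the Tensor accessor and include path are assumptions based on the surrounding code, not prescribed by this header):

#include <complex>
#include <ATen/kernels/lapack.h>

void solve_hermitian(container::Tensor& A, container::Tensor& eval, const int dim) {
    using T = std::complex<double>;
    // A: dim x dim Hermitian, column-major, CPU-resident; overwritten with eigenvectors.
    container::kernels::lapack_heevd<T, container::DEVICE_CPU> heevd;
    heevd(dim, A.data<T>(), /*lda=*/dim, eval.data<double>());
    // eval now holds all eigenvalues in ascending order.
}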
@@ -174,41 +291,6 @@ struct lapack_hegvx { }; -template -struct lapack_getrf { - void operator()( - const int& m, - const int& n, - T* Mat, - const int& lda, - int* ipiv); -}; - - -template -struct lapack_getri { - void operator()( - const int& n, - T* Mat, - const int& lda, - const int* ipiv, - T* work, - const int& lwork); -}; - -template -struct lapack_getrs { - void operator()( - const char& trans, - const int& n, - const int& nrhs, - T* A, - const int& lda, - const int* ipiv, - T* B, - const int& ldb); -}; - #if defined(__CUDA) || defined(__ROCM) // TODO: Use C++ singleton to manage the GPU handles void createGpuSolverHandle(); // create cusolver handle diff --git a/source/source_base/module_container/ATen/kernels/rocm/blas.hip.cu b/source/source_base/module_container/ATen/kernels/rocm/blas.hip.cu index 5ad275460c..9fa7f63f08 100644 --- a/source/source_base/module_container/ATen/kernels/rocm/blas.hip.cu +++ b/source/source_base/module_container/ATen/kernels/rocm/blas.hip.cu @@ -23,6 +23,19 @@ void destroyGpuBlasHandle() { } +template +struct blas_nrm2 { + T operator()( + const int n, + const T *x, + const int incx) + { + T result; + hipBlasConnector::nrm2(hipblas_handle, n, x, incx, &result); + return result; + } +}; + template struct blas_dot { void operator()( @@ -196,6 +209,11 @@ struct blas_gemm_batched_strided { }; // Explicitly instantiate functors for the types of functor registered. +template struct blas_nrm2; +template struct blas_nrm2; +template struct blas_nrm2 , DEVICE_GPU>; +template struct blas_nrm2, DEVICE_GPU>; + template struct blas_dot; template struct blas_dot; template struct blas_dot , DEVICE_GPU>; diff --git a/source/source_base/module_container/ATen/kernels/test/blas_test.cpp b/source/source_base/module_container/ATen/kernels/test/blas_test.cpp index a01ef8beb9..d0c53422d4 100644 --- a/source/source_base/module_container/ATen/kernels/test/blas_test.cpp +++ b/source/source_base/module_container/ATen/kernels/test/blas_test.cpp @@ -20,6 +20,39 @@ class BlasTest : public testing::Test { TYPED_TEST_SUITE(BlasTest, base::utils::Types); +TYPED_TEST(BlasTest, Copy) { + using Type = typename std::tuple_element<0, decltype(TypeParam())>::type; + using Device = typename std::tuple_element<1, decltype(TypeParam())>::type; + + blas_copy copyCalculator; + + const int n = 3; + const Tensor x = std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0)}).to_device()); + Tensor y = std::move(Tensor({static_cast(0.0), static_cast(0.0), static_cast(0.0)}).to_device()); + + copyCalculator(n, x.data(), 1, y.data(), 1); + const Tensor expected = std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0)}).to_device()); + + EXPECT_EQ(y, expected); +} + +TYPED_TEST(BlasTest, Nrm2) { + using Type = typename std::tuple_element<0, decltype(TypeParam())>::type; + using Device = typename std::tuple_element<1, decltype(TypeParam())>::type; + + blas_nrm2 nrm2Calculator; + + const int n = 3; + const Tensor x = std::move(Tensor({static_cast(3.0), static_cast(4.0), static_cast(0.0)}).to_device()); + + using Real = typename GetTypeReal::type; + Real result = {}; + result = nrm2Calculator(n, x.data(), 1); + const Real expected = static_cast(5.0); + + EXPECT_NEAR(result, expected, static_cast(1e-6)); +} + TYPED_TEST(BlasTest, Dot) { using Type = typename std::tuple_element<0, decltype(TypeParam())>::type; using Device = typename std::tuple_element<1, decltype(TypeParam())>::type; @@ -29,7 +62,7 @@ TYPED_TEST(BlasTest, Dot) { const int n = 3; const Tensor x = 
std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0)}).to_device()); const Tensor y = std::move(Tensor({static_cast(4.0), static_cast(5.0), static_cast(6.0)}).to_device()); - + Type result = {}; dotCalculator(n, x.data(), 1, y.data(), 1, &result); const Type expected = static_cast(32.0); @@ -46,7 +79,7 @@ TYPED_TEST(BlasTest, Scal) { const int n = 3; const Type alpha = static_cast(2.0); Tensor x = std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0)}).to_device()); - + scalCalculator(n, &alpha, x.data(), 1); const Tensor expected = std::move(Tensor({static_cast(2.0), static_cast(4.0), static_cast(6.0)}).to_device()); @@ -64,7 +97,7 @@ TYPED_TEST(BlasTest, Axpy) { const Type alpha = static_cast(2.0); const Tensor x = std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0)}).to_device()); Tensor y = std::move(Tensor({static_cast(4.0), static_cast(5.0), static_cast(6.0)}).to_device()); - + axpyCalculator(n, &alpha, x.data(), 1, y.data(), 1); const Tensor expected = std::move(Tensor({static_cast(6.0), static_cast(9.0), static_cast(12.0)}).to_device()); @@ -83,11 +116,11 @@ TYPED_TEST(BlasTest, Gemv) { const int n = 2; const Type alpha = static_cast(2.0); const Type beta = static_cast(3.0); - const Tensor A = std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0), + const Tensor A = std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0), static_cast(4.0), static_cast(5.0), static_cast(6.0)}).to_device()); const Tensor x = std::move(Tensor({static_cast(1.0), static_cast(2.0)}).to_device()); Tensor y = std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0)}).to_device()); - + gemvCalculator(trans, m, n, &alpha, A.data(), m, x.data(), 1, &beta, y.data(), 1); const Tensor expected = std::move(Tensor({static_cast(21.0), static_cast(30.0), static_cast(39.0)}).to_device()); @@ -114,14 +147,14 @@ TYPED_TEST(BlasTest, GemvBatched) { std::vector y = {}; const Tensor _A = std::move(Tensor({ - static_cast(1.0), static_cast(2.0), - static_cast(3.0), static_cast(4.0), + static_cast(1.0), static_cast(2.0), + static_cast(3.0), static_cast(4.0), static_cast(5.0), static_cast(6.0), - + static_cast(7.0), static_cast(8.0), static_cast(9.0), static_cast(10.0), static_cast(11.0),static_cast(12.0)}).to_device()); - + A.push_back(_A.data()); A.push_back(_A.data() + m * n); @@ -164,14 +197,14 @@ TYPED_TEST(BlasTest, GemvBatchedStrided) { std::vector y = {}; const Tensor _A = std::move(Tensor({ - static_cast(1.0), static_cast(2.0), - static_cast(3.0), static_cast(4.0), + static_cast(1.0), static_cast(2.0), + static_cast(3.0), static_cast(4.0), static_cast(5.0), static_cast(6.0), - + static_cast(7.0), static_cast(8.0), static_cast(9.0), static_cast(10.0), static_cast(11.0),static_cast(12.0)}).to_device()); - + A.push_back(_A.data()); A.push_back(_A.data() + m * n); @@ -205,11 +238,11 @@ TYPED_TEST(BlasTest, Gemm) { const int n = 2; const Type alpha = static_cast(2.0); const Type beta = static_cast(3.0); - const Tensor A = std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0), + const Tensor A = std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0), static_cast(4.0), static_cast(5.0), static_cast(6.0)}).to_device()); const Tensor x = std::move(Tensor({static_cast(1.0), static_cast(2.0)}).to_device()); Tensor y = std::move(Tensor({static_cast(1.0), static_cast(2.0), static_cast(3.0)}).to_device()); - + gemmCalculator(trans, trans, m, 1, n, &alpha, A.data(), m, x.data(), n, &beta, y.data(), 
m); const Tensor expected = std::move(Tensor({static_cast(21.0), static_cast(30.0), static_cast(39.0)}).to_device()); @@ -237,14 +270,14 @@ TYPED_TEST(BlasTest, GemmBatched) { std::vector y2 = {}; const Tensor _A = std::move(Tensor({ - static_cast(1.0), static_cast(2.0), - static_cast(3.0), static_cast(4.0), + static_cast(1.0), static_cast(2.0), + static_cast(3.0), static_cast(4.0), static_cast(5.0), static_cast(6.0), - + static_cast(7.0), static_cast(8.0), static_cast(9.0), static_cast(10.0), static_cast(11.0),static_cast(12.0)}).to_device()); - + A.push_back(_A.data()); A.push_back(_A.data() + m * n); @@ -287,14 +320,14 @@ TYPED_TEST(BlasTest, GemmBatchedStrided) { std::vector y2 = {}; const Tensor _A = std::move(Tensor({ - static_cast(1.0), static_cast(2.0), - static_cast(3.0), static_cast(4.0), + static_cast(1.0), static_cast(2.0), + static_cast(3.0), static_cast(4.0), static_cast(5.0), static_cast(6.0), - + static_cast(7.0), static_cast(8.0), static_cast(9.0), static_cast(10.0), static_cast(11.0),static_cast(12.0)}).to_device()); - + A.push_back(_A.data()); A.push_back(_A.data() + m * n); diff --git a/source/source_base/module_container/ATen/kernels/test/lapack_test.cpp b/source/source_base/module_container/ATen/kernels/test/lapack_test.cpp index acd903bb60..5524ca6c50 100644 --- a/source/source_base/module_container/ATen/kernels/test/lapack_test.cpp +++ b/source/source_base/module_container/ATen/kernels/test/lapack_test.cpp @@ -92,6 +92,83 @@ TYPED_TEST(LapackTest, Potrf) { EXPECT_EQ(A, C); } +// lapack_geqrf_inplace, +// check that QtQ = I +TYPED_TEST(LapackTest, GeqrfInPlace) { + using Type = typename std::tuple_element<0, decltype(TypeParam())>::type; + using Device = typename std::tuple_element<1, decltype(TypeParam())>::type; + + lapack_geqrf_inplace geqrfCalculator; + + const int m = 4; + const int n = 3; // m >= n,Q is m x n column-orthogonal matrix + const int lda = m; + + Tensor A_input = std::move(Tensor({ + static_cast(1.0), static_cast(2.0), static_cast(3.0), static_cast(4.0), + static_cast(5.0), static_cast(6.0), static_cast(7.0), static_cast(8.0), + static_cast(9.0), static_cast(10.0), static_cast(11.0), static_cast(12.0) + }).to_device()); + + Tensor A = A_input; // will be overwritten as Q + + // do geqrf -> get orthogonal Q + geqrfCalculator(m, n, A.data(), lda); + + // check on CPU + Tensor Q = A.to_device(); + const Type* Q_data = Q.data(); + + // compute QtQ = Q^T * Q (n x n) + Tensor QtQ = Q; // std::move(Tensor(std::vector(n * n, static_cast(0.0))).to_device()); + const Type alpha = static_cast(1.0); + const Type beta = static_cast(0.0); + + blas_gemm gemm; + gemm('C', 'N', // Q^T * Q + n, n, m, // n x n + &alpha, + Q_data, lda, // Q^T + Q_data, lda, // Q + &beta, + QtQ.data(), n); + + // To print value: first to_device CPU, then print + // // Test code: print A + // std::cout << "A = " << std::endl; + // for (int i = 0; i < m; ++i) { + // for (int j = 0; j < n; ++j) { + // std::cout << A_input.to_device().data()[i + j * m] << " "; + // } + // std::cout << std::endl; + // } + // // Test code: print Q + // std::cout << "Q = " << std::endl; + // for (int i = 0; i < m; ++i) { + // for (int j = 0; j < n; ++j) { + // std::cout << Q.data()[i + j * m] << " "; + // } + // std::cout << std::endl; + // } + // // Test code: print QtQ + // std::cout << "QtQ = " << std::endl; + // for (int i = 0; i < n; ++i) { + // for (int j = 0; j < n; ++j) { + // std::cout << QtQ.data()[i + j * n] << " "; + // } + // std::cout << std::endl; + // } + + // check QtQ + for (int i = 0; i < n; 
++i) {
+        for (int j = 0; j < n; ++j) {
+            Type expected = (i == j) ? static_cast<Type>(1.0) : static_cast<Type>(0.0);
+            EXPECT_NEAR(std::abs(QtQ.data<Type>()[i + j * n]), std::abs(expected), 1e-5)
+                << "Q^T * Q not identity at (" << i << "," << j << ")";
+        }
+    }
+}
+
 // Test for lapack_heevd and lapack_heevx:
 // Solve a standard eigenvalue problem
 // and check that A*V = V*E
@@ -124,7 +201,8 @@ TYPED_TEST(LapackTest, heevd) {
     const Type beta = static_cast<Type>(0.0);
     // Note all blas and lapack operators within container are column major!
     // For this reason, we should employ 'L' instead of 'U' in the subsequent line.
-    heevdCalculator('V', 'U', B.data<Type>(), dim, E.data<Real>());
+    // Old interface: heevdCalculator('V', 'U', B.data<Type>(), dim, E.data<Real>());
+    heevdCalculator(dim, B.data<Type>(), dim, E.data<Real>());
 
     E = E.to_device<DEVICE_CPU>();
     const Tensor Alpha = std::move(Tensor({
diff --git a/source/source_base/module_container/base/third_party/blas.h b/source/source_base/module_container/base/third_party/blas.h
index b36df5a39b..1fdbac67b2 100644
--- a/source/source_base/module_container/base/third_party/blas.h
+++ b/source/source_base/module_container/base/third_party/blas.h
@@ -25,14 +25,17 @@
 void daxpy_(const int *N, const double *alpha, const double *x, const int *incx, double *y, const int *incy);
 void caxpy_(const int *N, const std::complex<float> *alpha, const std::complex<float> *x, const int *incx, std::complex<float> *y, const int *incy);
 void zaxpy_(const int *N, const std::complex<double> *alpha, const std::complex<double> *x, const int *incx, std::complex<double> *y, const int *incy);
 
-void dcopy_(const int *n, const double *a, const int *incx, double *b, const int *incy);
-void zcopy_(const int *n, const std::complex<double> *a, const int *incx, std::complex<double> *b, const int *incy);
+void scopy_(const int *n, const float *a, const int *incx, float *b, const int *incy);
+void dcopy_(const int *n, const double *a, const int *incx, double *b, const int *incy);
+void ccopy_(const int *n, const std::complex<float> *a, const int *incx, std::complex<float> *b, const int *incy);
+void zcopy_(const int *n, const std::complex<double> *a, const int *incx, std::complex<double> *b, const int *incy);
+
 //reason for passing results as argument instead of returning it:
 //see https://www.numbercrunch.de/blog/2014/07/lost-in-translation/
-void cdotc_(const int *n, const std::complex<float> *zx, const int *incx, 
+void cdotc_(const int *n, const std::complex<float> *zx, const int *incx,
             const std::complex<float> *zy, const int *incy, std::complex<float> *result);
-void zdotc_(const int *n, const std::complex<double> *zx, const int *incx, 
+void zdotc_(const int *n, const std::complex<double> *zx, const int *incx,
             const std::complex<double> *zy, const int *incy, std::complex<double> *result);
 // Peize Lin add ?dot 2017-10-27, to compute d=x*y
 float sdot_(const int *N, const float *x, const int *incx, const float *y, const int *incy);
 double ddot_(const int *N, const double *x, const int *incx, const double *y, const int *incy);
 
 // Peize Lin add ?nrm2 2018-06-12, to compute out = ||x||_2 = \sqrt{ \sum_i x_i**2 }
 float snrm2_( const int *n, const float *x, const int *incx );
 double dnrm2_( const int *n, const double *x, const int *incx );
+float scnrm2_( const int *n, const std::complex<float> *x, const int *incx );
 double dznrm2_( const int *n, const std::complex<double> *x, const int *incx );
 
 // level 2: matrix-std::vector operations, O(n^2) data and O(n^2) work.
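The incx/incy arguments above are element strides, which is the detail callers most often get wrong. A short self-contained sketch of the copy/nrm2 wrappers this hunk declares (BlasConnector is the namespace the kernel functors earlier in this patch dispatch to):

void stride_sketch() {
    // Gather every other element of x into y, then take the Euclidean norm of y.
    double x[5] = {3.0, -1.0, 4.0, -1.0, 0.0};
    double y[3] = {};
    container::BlasConnector::copy(3, x, /*incx=*/2, y, /*incy=*/1);  // y = {3, 4, 0}
    const double r = container::BlasConnector::nrm2(3, y, /*incx=*/1); // r == 5.0
}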
@@ -344,6 +348,11 @@ double nrm2( const int n, const double *x, const int incx ) return dnrm2_( &n, x, &incx ); } static inline +double nrm2( const int n, const std::complex *x, const int incx ) +{ + return scnrm2_( &n, x, &incx ); +} +static inline double nrm2( const int n, const std::complex *x, const int incx ) { return dznrm2_( &n, x, &incx ); @@ -351,11 +360,22 @@ double nrm2( const int n, const std::complex *x, const int incx ) // copies a into b static inline +void copy(const int n, const float *a, const int incx, float *b, const int incy) +{ + scopy_(&n, a, &incx, b, &incy); +} +static inline void copy(const int n, const double *a, const int incx, double *b, const int incy) + { dcopy_(&n, a, &incx, b, &incy); } static inline +void copy(const int n, const std::complex *a, const int incx, std::complex *b, const int incy) +{ + ccopy_(&n, a, &incx, b, &incy); +} +static inline void copy(const int n, const std::complex *a, const int incx, std::complex *b, const int incy) { zcopy_(&n, a, &incx, b, &incy); diff --git a/source/source_base/module_container/base/third_party/cublas.h b/source/source_base/module_container/base/third_party/cublas.h index 34cd7484d1..cea046f30e 100644 --- a/source/source_base/module_container/base/third_party/cublas.h +++ b/source/source_base/module_container/base/third_party/cublas.h @@ -8,6 +8,48 @@ namespace container { namespace cuBlasConnector { +static inline +void copy(cublasHandle_t& handle, const int& n, const float *x, const int& incx, float *y, const int& incy) +{ + cublasErrcheck(cublasScopy(handle, n, x, incx, y, incy)); +} +static inline +void copy(cublasHandle_t& handle, const int& n, const double *x, const int& incx, double *y, const int& incy) +{ + cublasErrcheck(cublasDcopy(handle, n, x, incx, y, incy)); +} +static inline +void copy(cublasHandle_t& handle, const int& n, const std::complex *x, const int& incx, std::complex *y, const int& incy) +{ + cublasErrcheck(cublasCcopy(handle, n, reinterpret_cast(x), incx, reinterpret_cast(y), incy)); +} +static inline +void copy(cublasHandle_t& handle, const int& n, const std::complex *x, const int& incx, std::complex *y, const int& incy) +{ + cublasErrcheck(cublasZcopy(handle, n, reinterpret_cast(x), incx, reinterpret_cast(y), incy)); +} + +static inline +void nrm2(cublasHandle_t& handle, const int& n, const float *x, const int& incx, float* result) +{ + cublasErrcheck(cublasSnrm2(handle, n, x, incx, result)); +} +static inline +void nrm2(cublasHandle_t& handle, const int& n, const double *x, const int& incx, double* result) +{ + cublasErrcheck(cublasDnrm2(handle, n, x, incx, result)); +} +static inline +void nrm2(cublasHandle_t& handle, const int& n, const std::complex *x, const int& incx, float* result) +{ + cublasErrcheck(cublasScnrm2(handle, n, reinterpret_cast(x), incx, result)); +} +static inline +void nrm2(cublasHandle_t& handle, const int& n, const std::complex *x, const int& incx, double* result) +{ + cublasErrcheck(cublasDznrm2(handle, n, reinterpret_cast(x), incx, result)); +} + static inline void dot(cublasHandle_t& handle, const int& n, const float *x, const int& incx, const float *y, const int& incy, float* result) { @@ -90,7 +132,7 @@ void gemv(cublasHandle_t& handle, const char& trans, const int& m, const int& n, const std::complex& alpha, const std::complex *A, const int& lda, const std::complex *x, const int& incx, const std::complex& beta, std::complex *y, const int& incy) { - cublasErrcheck(cublasCgemv(handle, GetCublasOperation(trans), m, n, reinterpret_cast(&alpha), + 
cublasErrcheck(cublasCgemv(handle, GetCublasOperation(trans), m, n, reinterpret_cast(&alpha), reinterpret_cast(A), lda, reinterpret_cast(x), incx, reinterpret_cast(&beta), reinterpret_cast(y), incy)); } static inline @@ -98,7 +140,7 @@ void gemv(cublasHandle_t& handle, const char& trans, const int& m, const int& n, const std::complex& alpha, const std::complex *A, const int& lda, const std::complex *x, const int& incx, const std::complex& beta, std::complex *y, const int& incy) { - cublasErrcheck(cublasZgemv(handle, GetCublasOperation(trans), m, n, reinterpret_cast(&alpha), + cublasErrcheck(cublasZgemv(handle, GetCublasOperation(trans), m, n, reinterpret_cast(&alpha), reinterpret_cast(A), lda, reinterpret_cast(x), incx, reinterpret_cast(&beta), reinterpret_cast(y), incy)); } @@ -148,11 +190,11 @@ void gemm(cublasHandle_t& handle, const char& transa, const char& transb, const const std::complex& beta, std::complex* C, const int& ldc) { cublasErrcheck(cublasCgemm(handle, GetCublasOperation(transa), GetCublasOperation(transb), - m, n, k, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(&beta), + m, n, k, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb, + reinterpret_cast(&beta), reinterpret_cast(C), ldc)); } static inline @@ -162,15 +204,15 @@ void gemm(cublasHandle_t& handle, const char& transa, const char& transb, const { cublasErrcheck(cublasZgemm(handle, GetCublasOperation(transa), GetCublasOperation(transb), m, n, k, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(&beta), + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb, + reinterpret_cast(&beta), reinterpret_cast(C), ldc)); } template -static inline +static inline T** allocate_(T** in, const int& batch_size) { T** out = nullptr; @@ -216,11 +258,11 @@ void gemm_batched(cublasHandle_t& handle, const char& transa, const char& transb std::complex** d_B = allocate_(B, batch_size); std::complex** d_C = allocate_(C, batch_size); cublasErrcheck(cublasCgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb), - m, n, k, - reinterpret_cast(&alpha), - reinterpret_cast(d_A), lda, - reinterpret_cast(d_B), ldb, - reinterpret_cast(&beta), + m, n, k, + reinterpret_cast(&alpha), + reinterpret_cast(d_A), lda, + reinterpret_cast(d_B), ldb, + reinterpret_cast(&beta), reinterpret_cast(d_C), ldc, batch_size)); cudaErrcheck(cudaFree(d_A)); cudaErrcheck(cudaFree(d_B)); @@ -235,11 +277,11 @@ void gemm_batched(cublasHandle_t& handle, const char& transa, const char& transb std::complex** d_B = allocate_(B, batch_size); std::complex** d_C = allocate_(C, batch_size); cublasErrcheck(cublasZgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb), - m, n, k, - reinterpret_cast(&alpha), - reinterpret_cast(d_A), lda, - reinterpret_cast(d_B), ldb, - reinterpret_cast(&beta), + m, n, k, + reinterpret_cast(&alpha), + reinterpret_cast(d_A), lda, + reinterpret_cast(d_B), ldb, + reinterpret_cast(&beta), reinterpret_cast(d_C), ldc, batch_size)); cudaErrcheck(cudaFree(d_A)); cudaErrcheck(cudaFree(d_B)); @@ -252,14 +294,14 @@ void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char const float& beta, float* C, const int& ldc, const int& stride_c, const int& batch_size) { cublasErrcheck(cublasSgemmStridedBatched( - handle, - GetCublasOperation(transa), + handle, + GetCublasOperation(transa), GetCublasOperation(transb), - m, n, k, - &alpha, - A, 
lda, stride_a, - B, ldb, stride_b, - &beta, + m, n, k, + &alpha, + A, lda, stride_a, + B, ldb, stride_b, + &beta, C, ldc, stride_c, batch_size)); } @@ -269,14 +311,14 @@ void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char const double& beta, double* C, const int& ldc, const int& stride_c, const int& batch_size) { cublasErrcheck(cublasDgemmStridedBatched( - handle, - GetCublasOperation(transa), + handle, + GetCublasOperation(transa), GetCublasOperation(transb), - m, n, k, - &alpha, - A, lda, stride_a, - B, ldb, stride_b, - &beta, + m, n, k, + &alpha, + A, lda, stride_a, + B, ldb, stride_b, + &beta, C, ldc, stride_c, batch_size)); } @@ -286,14 +328,14 @@ void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char const std::complex& beta, std::complex* C, const int& ldc, const int& stride_c, const int& batch_size) { cublasErrcheck(cublasCgemmStridedBatched( - handle, - GetCublasOperation(transa), + handle, + GetCublasOperation(transa), GetCublasOperation(transb), - m, n, k, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, stride_a, - reinterpret_cast(B), ldb, stride_b, - reinterpret_cast(&beta), + m, n, k, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, stride_a, + reinterpret_cast(B), ldb, stride_b, + reinterpret_cast(&beta), reinterpret_cast(C), ldc, stride_c, batch_size)); } @@ -303,14 +345,14 @@ void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char const std::complex& beta, std::complex* C, const int& ldc, const int& stride_c, const int& batch_size) { cublasErrcheck(cublasZgemmStridedBatched( - handle, - GetCublasOperation(transa), + handle, + GetCublasOperation(transa), GetCublasOperation(transb), - m, n, k, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, stride_a, - reinterpret_cast(B), ldb, stride_b, - reinterpret_cast(&beta), + m, n, k, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, stride_a, + reinterpret_cast(B), ldb, stride_b, + reinterpret_cast(&beta), reinterpret_cast(C), ldc, stride_c, batch_size)); } @@ -318,4 +360,4 @@ void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char } // namespace cuBlasConnector } // namespace container -#endif // BASE_THIRD_PARTY_CUBLAS_H_ \ No newline at end of file +#endif // BASE_THIRD_PARTY_CUBLAS_H_ diff --git a/source/source_base/module_container/base/third_party/cusolver.h b/source/source_base/module_container/base/third_party/cusolver.h index affdc4ca48..01fdda7edb 100644 --- a/source/source_base/module_container/base/third_party/cusolver.h +++ b/source/source_base/module_container/base/third_party/cusolver.h @@ -3,6 +3,16 @@ #include #include + +// #include // traits, needed if generic API is used. +// header provided by cusolver, including some data types and macros. +// see https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuSOLVER/utils/cusolver_utils.h +// The cuSolverDN library provides two different APIs; legacy and generic. +// https://docs.nvidia.com/cuda/cusolver/index.html#naming-conventions +// now only legacy APIs are used, while the general APIs have the potential to simplify code implementation. +// for example, cucusolverDnXpotrf/getrf/geqrf/sytrf +// More tests are needed to confirm that the generic APIs are operating normally, as they are not yet fully supported. 
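At the functor level, all of the boilerplate below (workspace query, device allocation, info check) collapses into one call. A hedged usage sketch of the GPU path added in ATen/kernels/cuda/lapack.cu; the Tensor constructor mirrors the one used in lapack.cpp, and the data<T>() accessor is an assumption:

void gpu_qr_sketch(const int m, const int n) {
    // m x n column-major matrix on the GPU whose columns we want orthonormalized.
    container::Tensor A(container::DataTypeToEnum<double>::value,
                        container::DeviceType::GpuDevice, {m * n});
    // ... fill A on the device ...
    container::kernels::lapack_geqrf_inplace<double, container::DEVICE_GPU> qr;
    qr(m, n, A.data<double>(), /*lda=*/m); // A's columns are replaced by those of Q
}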
+ #include namespace container { @@ -1136,6 +1146,433 @@ void getrs(cusolverDnHandle_t& cusolver_handle, const char& trans, const int& n, cudaErrcheck(cudaFree(d_info)); } +// QR decomposition +// geqrf, orgqr +// Note: +// there are two cusolver geqrf +// one is cusolverDngeqrf +// one is cusolverDnXgeqrf +// which one is better? +// +// template +// static inline void geqrf( +// cusolverDnHandle_t& cusolver_handle, +// const int64_t m, +// const int64_t n, +// T* d_A, // device matrix A (m x n, column-major) +// const int64_t lda, +// T* d_tau // output: scalar factors of elementary reflectors +// ) { +// // query workspace size +// int *d_info = nullptr; /* error info */ +// +// size_t workspaceInBytesOnDevice = 0; /* size of workspace */ +// void *d_work = nullptr; /* device workspace */ +// size_t workspaceInBytesOnHost = 0; /* size of workspace */ +// void *h_work = nullptr; /* host workspace */ +// +// cudaErrcheck(cudaMalloc(reinterpret_cast(&d_info), sizeof(int))); +// +// cusolverDnParams_t params = NULL; +// cusolverErrcheck(cusolverDnCreateParams(¶ms)); +// +// cusolverErrcheck(cusolverDnXgeqrf_bufferSize( +// cusolver_handle, +// params, +// m, n, +// traits::cuda_data_type, +// d_A, +// lda, +// traits::cuda_data_type, +// d_tau, +// traits::cuda_data_type, +// &workspaceInBytesOnDevice, +// &workspaceInBytesOnHost +// )); +// +// // allocate device workspace +// cudaErrcheck(cudaMalloc(reinterpret_cast(&d_work), workspaceInBytesOnDevice)); +// +// // allocate host workspace +// if (workspaceInBytesOnHost > 0) { +// h_work = reinterpret_cast(malloc(workspaceInBytesOnHost)); +// if (h_work == nullptr) { +// throw std::runtime_error("Error: h_work not allocated."); +// } +// } +// +// // QR factorization +// cusolverErrcheck(cusolverDnXgeqrf( +// cusolver_handle, +// params, +// m, n, +// traits::cuda_data_type, +// d_A, +// lda, +// traits::cuda_data_type, +// d_tau, +// traits::cuda_data_type, +// d_work, +// workspaceInBytesOnDevice, +// h_work, +// workspaceInBytesOnHost, +// d_info +// )); +// +// // check info +// int h_info = 0; +// cudaErrcheck(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost)); +// if (h_info != 0) { +// // std::printf("%d-th parameter is wrong \n", -info); +// // print error message +// std::cout << -h_info << "th parameter is wrong" << std::endl; +// throw std::runtime_error("geqrf: failed to compute QR decomposition"); +// } +// +// // clean workspace +// cudaErrcheck(cudaFree(d_info)); +// cudaErrcheck(cudaFree(d_work)); +// if (h_work) free(h_work); +// cusolverErrcheck(cusolverDnDestroyParams(params)); +// } + +// geqrf + +// --- float --- +static inline void geqrf( + cusolverDnHandle_t& cusolver_handle, + const int m, + const int n, + float* d_A, + const int lda, + float* d_tau +) { + int lwork = 0; + cusolverErrcheck(cusolverDnSgeqrf_bufferSize( + cusolver_handle, m, n, d_A, lda, &lwork)); + + float* d_work = nullptr; + int* d_info = nullptr; + + if (lwork > 0) { + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_work), sizeof(float) * lwork)); + } + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_info), sizeof(int))); + + cusolverErrcheck(cusolverDnSgeqrf( + cusolver_handle, m, n, d_A, lda, d_tau, d_work, lwork, d_info)); + + int h_info = 0; + cudaErrcheck(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost)); + if (h_info != 0) { + std::cout << "geqrf (S): info = " << h_info << std::endl; + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); + throw std::runtime_error("geqrf (S): QR factorization 
failed"); + } + + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); +} + +// --- double --- +static inline void geqrf( + cusolverDnHandle_t& cusolver_handle, + const int m, + const int n, + double* d_A, + const int lda, + double* d_tau +) { + int lwork = 0; + cusolverErrcheck(cusolverDnDgeqrf_bufferSize( + cusolver_handle, m, n, d_A, lda, &lwork)); + + double* d_work = nullptr; + int* d_info = nullptr; + + if (lwork > 0) { + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_work), sizeof(double) * lwork)); + } + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_info), sizeof(int))); + + cusolverErrcheck(cusolverDnDgeqrf( + cusolver_handle, m, n, d_A, lda, d_tau, d_work, lwork, d_info)); + + int h_info = 0; + cudaErrcheck(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost)); + if (h_info != 0) { + std::cout << "geqrf (D): info = " << h_info << std::endl; + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); + throw std::runtime_error("geqrf (D): QR factorization failed"); + } + + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); +} + +// --- std::complex --- +static inline void geqrf( + cusolverDnHandle_t& cusolver_handle, + const int m, + const int n, + std::complex* d_A, + const int lda, + std::complex* d_tau +) { + int lwork = 0; + cusolverErrcheck(cusolverDnCgeqrf_bufferSize( + cusolver_handle, m, n, + reinterpret_cast(d_A), + lda, + &lwork // ← 这里才是 lwork 的地址! + )); + + cuComplex* d_work = nullptr; + int* d_info = nullptr; + + if (lwork > 0) { + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_work), sizeof(cuComplex) * lwork)); + } + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_info), sizeof(int))); + + cusolverErrcheck(cusolverDnCgeqrf( + cusolver_handle, m, n, + reinterpret_cast(d_A), + lda, + reinterpret_cast(d_tau), // ← 这里才是 d_tau + d_work, lwork, d_info)); + + int h_info = 0; + cudaErrcheck(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost)); + if (h_info != 0) { + std::cout << "geqrf (C): info = " << h_info << std::endl; + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); + throw std::runtime_error("geqrf (C): QR factorization failed"); + } + + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); +} + +// --- std::complex --- +static inline void geqrf( + cusolverDnHandle_t& cusolver_handle, + const int m, + const int n, + std::complex* d_A, + const int lda, + std::complex* d_tau +) { + int lwork = 0; + cusolverErrcheck(cusolverDnZgeqrf_bufferSize( + cusolver_handle, m, n, + reinterpret_cast(d_A), + lda, + &lwork + )); + + cuDoubleComplex* d_work = nullptr; + int* d_info = nullptr; + + if (lwork > 0) { + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_work), sizeof(cuDoubleComplex) * lwork)); + } + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_info), sizeof(int))); + + cusolverErrcheck(cusolverDnZgeqrf( + cusolver_handle, m, n, + reinterpret_cast(d_A), + lda, + reinterpret_cast(d_tau), + d_work, lwork, d_info)); + + int h_info = 0; + cudaErrcheck(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost)); + if (h_info != 0) { + std::cout << "geqrf (Z): info = " << h_info << std::endl; + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); + throw std::runtime_error("geqrf (Z): QR factorization failed"); + } + + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); +} + + +// --- float --- +static inline void orgqr( + cusolverDnHandle_t& cusolver_handle, + const int m, + const int n, + 
const int k, + float* d_A, + const int lda, + float* d_tau +) { + int lwork = 0; + cusolverErrcheck(cusolverDnSorgqr_bufferSize( + cusolver_handle, m, n, k, d_A, lda, d_tau, &lwork)); + + float* d_work = nullptr; + int* d_info = nullptr; + + if (lwork > 0) { + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_work), sizeof(float) * lwork)); + } + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_info), sizeof(int))); + + cusolverErrcheck(cusolverDnSorgqr( + cusolver_handle, m, n, k, d_A, lda, d_tau, d_work, lwork, d_info)); + + int h_info = 0; + cudaErrcheck(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost)); + if (h_info != 0) { + std::cout << "orgqr (S): info = " << h_info << " (failure at parameter " << -h_info << ")" << std::endl; + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); + throw std::runtime_error("orgqr (S): failed to generate Q matrix"); + } + + // clean workspace + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); +} + +// --- double --- +static inline void orgqr( + cusolverDnHandle_t& cusolver_handle, + const int m, + const int n, + const int k, + double* d_A, + const int lda, + double* d_tau +) { + int lwork = 0; + cusolverErrcheck(cusolverDnDorgqr_bufferSize( + cusolver_handle, m, n, k, d_A, lda, d_tau, &lwork)); + + double* d_work = nullptr; + int* d_info = nullptr; + + if (lwork > 0) { + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_work), sizeof(double) * lwork)); + } + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_info), sizeof(int))); + + cusolverErrcheck(cusolverDnDorgqr( + cusolver_handle, m, n, k, d_A, lda, d_tau, d_work, lwork, d_info)); + + int h_info = 0; + cudaErrcheck(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost)); + if (h_info != 0) { + std::cout << "orgqr (D): info = " << h_info << std::endl; + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); + throw std::runtime_error("orgqr (D): failed to generate Q matrix"); + } + + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); +} + +// --- std::complex --- +static inline void orgqr( + cusolverDnHandle_t& cusolver_handle, + const int m, + const int n, + const int k, + std::complex* d_A, + const int lda, + std::complex* d_tau +) { + int lwork = 0; + cusolverErrcheck(cusolverDnCungqr_bufferSize( + cusolver_handle, m, n, k, + reinterpret_cast(d_A), + lda, + reinterpret_cast(d_tau), + &lwork)); + + cuComplex* d_work = nullptr; + int* d_info = nullptr; + + if (lwork > 0) { + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_work), sizeof(cuComplex) * lwork)); + } + cudaErrcheck(cudaMalloc(reinterpret_cast(&d_info), sizeof(int))); + + cusolverErrcheck(cusolverDnCungqr( + cusolver_handle, m, n, k, + reinterpret_cast(d_A), + lda, + reinterpret_cast(d_tau), + d_work, lwork, d_info)); + + int h_info = 0; + cudaErrcheck(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost)); + if (h_info != 0) { + std::cout << "orgqr (C): info = " << h_info << std::endl; + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); + throw std::runtime_error("orgqr (C): failed to generate Q matrix"); + } + + if (d_work) cudaErrcheck(cudaFree(d_work)); + cudaErrcheck(cudaFree(d_info)); +} + +// --- std::complex --- +static inline void orgqr( + cusolverDnHandle_t& cusolver_handle, + const int m, + const int n, + const int k, + std::complex* d_A, + const int lda, + std::complex* d_tau +) { + int lwork = 0; + cusolverErrcheck(cusolverDnZungqr_bufferSize( + cusolver_handle, m, n, k, + 
+        reinterpret_cast<cuDoubleComplex*>(d_A),
+        lda,
+        reinterpret_cast<cuDoubleComplex*>(d_tau),
+        &lwork));
+
+    cuDoubleComplex* d_work = nullptr;
+    int* d_info = nullptr;
+
+    if (lwork > 0) {
+        cudaErrcheck(cudaMalloc(reinterpret_cast<void**>(&d_work), sizeof(cuDoubleComplex) * lwork));
+    }
+    cudaErrcheck(cudaMalloc(reinterpret_cast<void**>(&d_info), sizeof(int)));
+
+    cusolverErrcheck(cusolverDnZungqr(
+        cusolver_handle, m, n, k,
+        reinterpret_cast<cuDoubleComplex*>(d_A),
+        lda,
+        reinterpret_cast<cuDoubleComplex*>(d_tau),
+        d_work, lwork, d_info));
+
+    int h_info = 0;
+    cudaErrcheck(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
+    if (h_info != 0) {
+        std::cout << "orgqr (Z): info = " << h_info << std::endl;
+        if (d_work) cudaErrcheck(cudaFree(d_work));
+        cudaErrcheck(cudaFree(d_info));
+        throw std::runtime_error("orgqr (Z): failed to generate Q matrix");
+    }
+
+    if (d_work) cudaErrcheck(cudaFree(d_work));
+    cudaErrcheck(cudaFree(d_info));
+}
+
+
 } // namespace cuSolverConnector
 } // namespace container
diff --git a/source/source_base/module_container/base/third_party/lapack.h b/source/source_base/module_container/base/third_party/lapack.h
index 9538761734..34881055fd 100644
--- a/source/source_base/module_container/base/third_party/lapack.h
+++ b/source/source_base/module_container/base/third_party/lapack.h
@@ -194,6 +194,19 @@ void cgetrs_(const char* trans, const int* n, const int* nrhs,
 void zgetrs_(const char* trans, const int* n, const int* nrhs,
              const std::complex<double>* A, const int* lda, const int* ipiv,
              std::complex<double>* B, const int* ldb, int* info);
+
+// QR factorization
+// compute R and the Householder reflectors that define Q
+void sgeqrf_(const int* m, const int* n, float* A, const int* lda, float* tau, float* work, const int* lwork, int* info);
+void dgeqrf_(const int* m, const int* n, double* A, const int* lda, double* tau, double* work, const int* lwork, int* info);
+void cgeqrf_(const int* m, const int* n, std::complex<float>* A, const int* lda, std::complex<float>* tau, std::complex<float>* work, const int* lwork, int* info);
+void zgeqrf_(const int* m, const int* n, std::complex<double>* A, const int* lda, std::complex<double>* tau, std::complex<double>* work, const int* lwork, int* info);
+// generate the explicit Q matrix from the reflectors
+void sorgqr_(const int* m, const int* n, const int* k, float* A, const int* lda, const float* tau, float* work, const int* lwork, int* info);
+void dorgqr_(const int* m, const int* n, const int* k, double* A, const int* lda, const double* tau, double* work, const int* lwork, int* info);
+void cungqr_(const int* m, const int* n, const int* k, std::complex<float>* A, const int* lda, const std::complex<float>* tau, std::complex<float>* work, const int* lwork, int* info);
+void zungqr_(const int* m, const int* n, const int* k, std::complex<double>* A, const int* lda, const std::complex<double>* tau, std::complex<double>* work, const int* lwork, int* info);
+
 }

 // Class LapackConnector provide the connector to fortran lapack routine.
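
For context, these prototypes follow LAPACK's standard two-phase workspace convention: a call with lwork = -1 performs a workspace query, writing the optimal size into work[0] without doing any computation. Below is a minimal sketch of how the geqrf/orgqr wrappers added in the next hunk could be driven on the CPU; the function name qr_orthonormalize, the use of std::vector, and the m >= n assumption are illustrative and not part of this patch.

    #include <algorithm>
    #include <stdexcept>
    #include <vector>

    // Sketch: orthonormalize the columns of an m x n (m >= n) column-major
    // matrix A in place via QR, using the LapackConnector overloads.
    void qr_orthonormalize(const int m, const int n, std::vector<double>& A)
    {
        const int lda = m;
        std::vector<double> tau(std::min(m, n));
        int info = 0;

        // Workspace query: lwork = -1 makes dgeqrf_ report the optimal
        // workspace size in work[0] instead of factorizing.
        double query = 0.0;
        container::lapackConnector::geqrf(m, n, A.data(), lda, tau.data(), &query, -1, info);
        const int lwork = std::max(static_cast<int>(query), n);
        std::vector<double> work(lwork);

        // Factorize: A is overwritten with R (upper triangle) and the
        // Householder reflectors (below the diagonal), scaled by tau.
        container::lapackConnector::geqrf(m, n, A.data(), lda, tau.data(), work.data(), lwork, info);
        if (info != 0) { throw std::runtime_error("geqrf failed"); }

        // Expand the first n columns of the explicit Q in place (k = n reflectors).
        container::lapackConnector::orgqr(m, n, n, A.data(), lda, tau.data(), work.data(), lwork, info);
        if (info != 0) { throw std::runtime_error("orgqr failed"); }
    }

Reusing the geqrf workspace for orgqr is legal here since both routines only require lwork >= n, though a separate orgqr workspace query would give the optimal size.
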
@@ -535,6 +548,49 @@ void getrs(const char& trans, const int n, const int nrhs, std::complex<double>*
     zgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info);
 }

+// LAPACK routines for QR decomposition
+static inline
+void geqrf(const int m, const int n, float* A, const int lda, float* tau, float* work, const int lwork, int& info)
+{
+    sgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
+}
+static inline
+void geqrf(const int m, const int n, double* A, const int lda, double* tau, double* work, const int lwork, int& info)
+{
+    dgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
+}
+static inline
+void geqrf(const int m, const int n, std::complex<float>* A, const int lda, std::complex<float>* tau, std::complex<float>* work, const int lwork, int& info)
+{
+    cgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
+}
+static inline
+void geqrf(const int m, const int n, std::complex<double>* A, const int lda, std::complex<double>* tau, std::complex<double>* work, const int lwork, int& info)
+{
+    zgeqrf_(&m, &n, A, &lda, tau, work, &lwork, &info);
+}
+// these routines generate the orthogonal (or unitary) matrix Q from the QR decomposition
+static inline
+void orgqr(const int m, const int n, const int k, float* A, const int lda, const float* tau, float* work, const int lwork, int& info)
+{
+    sorgqr_(&m, &n, &k, A, &lda, tau, work, &lwork, &info);
+}
+static inline
+void orgqr(const int m, const int n, const int k, double* A, const int lda, const double* tau, double* work, const int lwork, int& info)
+{
+    dorgqr_(&m, &n, &k, A, &lda, tau, work, &lwork, &info);
+}
+static inline
+void orgqr(const int m, const int n, const int k, std::complex<float>* A, const int lda, const std::complex<float>* tau, std::complex<float>* work, const int lwork, int& info)
+{
+    cungqr_(&m, &n, &k, A, &lda, tau, work, &lwork, &info);
+}
+static inline
+void orgqr(const int m, const int n, const int k, std::complex<double>* A, const int lda, const std::complex<double>* tau, std::complex<double>* work, const int lwork, int& info)
+{
+    zungqr_(&m, &n, &k, A, &lda, tau, work, &lwork, &info);
+}
+
 } // namespace lapackConnector
 } // namespace container
diff --git a/source/source_hsolver/diago_bpcg.cpp b/source/source_hsolver/diago_bpcg.cpp
index 8b9a51b7f8..d4db3d790b 100644
--- a/source/source_hsolver/diago_bpcg.cpp
+++ b/source/source_hsolver/diago_bpcg.cpp
@@ -112,14 +112,14 @@ void DiagoBPCG<T, Device>::line_minimize(
 // Finally, the last two!
 template <typename T, typename Device>
 void DiagoBPCG<T, Device>::orth_cholesky(
-    ct::Tensor& workspace_in, 
-    ct::Tensor& psi_out, 
-    ct::Tensor& hpsi_out, 
+    ct::Tensor& workspace_in,
+    ct::Tensor& psi_out,
+    ct::Tensor& hpsi_out,
     ct::Tensor& hsub_out)
 {
     // gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band)
     this->pmmcn.multiply(1.0, psi_out.data<T>(), psi_out.data<T>(), 0.0, hsub_out.data<T>());
-    
+
     // set hsub matrix to lower format;
     ct::kernels::set_matrix<T, ct_Device>()(
         'L', hsub_out.data<T>(), this->n_band);
@@ -209,7 +209,8 @@ void DiagoBPCG<T, Device>::diag_hsub(
     // gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band)
     this->pmmcn.multiply(1.0, hpsi_in.data<T>(), psi_in.data<T>(), 0.0, hsub_out.data<T>());

-    ct::kernels::lapack_heevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
+    // ct::kernels::lapack_heevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
+    ct::kernels::lapack_heevd<T, ct_Device>()(this->n_band, hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());

     return;
 }
@@ -235,15 +236,15 @@ void DiagoBPCG<T, Device>::calc_hsub_with_block(
     // hpsi_out[n_basis, n_band] = psi_out[n_basis, n_band] x hsub_out[n_band, n_band]
     this->rotate_wf(hsub_out, psi_out, workspace_in);
     this->rotate_wf(hsub_out, hpsi_out, workspace_in);
-    
+
     return;
 }

 template <typename T, typename Device>
 void DiagoBPCG<T, Device>::calc_hsub_with_block_exit(
-    ct::Tensor& psi_out, 
+    ct::Tensor& psi_out,
     ct::Tensor& hpsi_out,
-    ct::Tensor& hsub_out, 
+    ct::Tensor& hsub_out,
     ct::Tensor& workspace_in,
     ct::Tensor& eigenvalue_out)
 {