Skip to content

Commit 4167001

Browse files
committed
checking sgemm and dgemm interface
1 parent 7b36531 commit 4167001

File tree

3 files changed

+39
-0
lines changed

3 files changed

+39
-0
lines changed

cutlass.cu

+4
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
#include <cutlass/wmma_matrix.h>
33
#include <cutlass/gemm/gemm.h>
44
#include <cutlass/gemm/wmma_gemm_traits.h>
5+
#include "cutlass/gemm/sgemm_traits.h"
6+
#include "cutlass/gemm/dgemm_traits.h"
57
#include <gemm-test/gemm_testbed.h>
68
#include <gemm-test/gemm.h>
79

810
int main(int argc, char* argv[]) {
911

1012
#include <gemm-test/wmma_tests.h>
13+
#include <gemm-test/sgemm_tests.h>
14+
#include <gemm-test/dgemm_tests.h>
1115
}
1216

1317

gemm-test/dgemm_tests.h

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
2+
cutlass::MatrixLayout::kRowMajor,
3+
cutlass::Shape<8, 32, 64> > DGemmTraits1;
4+
run_gemm<DGemmTraits1>(64, 32, 8);
5+
6+
7+
8+
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
9+
cutlass::MatrixLayout::kRowMajor,
10+
cutlass::Shape<8, 32, 64> > DGemmTraits2;
11+
run_gemm<DGemmTraits2>(256, 128, 64);
12+
13+
14+
15+
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
16+
cutlass::MatrixLayout::kRowMajor,
17+
cutlass::Shape<8, 64, 64> > DGemmTraits3;
18+
run_gemm<DGemmTraits3>(64, 64, 8);
19+

gemm-test/sgemm_tests.h

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
2+
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
3+
SgemmTraits1;
4+
run_gemm<SgemmTraits1>(1024, 512, 8);
5+
6+
7+
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
8+
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
9+
SgemmTraits2;
10+
run_gemm<SgemmTraits2>(128, 81, 1);
11+
12+
13+
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
14+
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
15+
SgemmTraits3;
16+
run_gemm<SgemmTraits3>(128, 112, 8);

0 commit comments

Comments
 (0)