
Commit d5e30a2

Author: Jian Weng (authored and committed)
Commit message: merge conflict
1 parent 28f9932 commit d5e30a2

File tree: 6 files changed, +89 -112 lines


.ycm_extra_conf.py

+4 -0

@@ -41,6 +41,10 @@
 # For a C project, you would set this to 'c' instead of 'c++'.
 '-x',
 'c++',
+'-x',
+'cuda'
+'-I',
+'/usr/local/cuda/include'
 ]

tensorcore/poc

694 KB (binary file not shown)

tensorcore/poc.cu

+69 -53

@@ -1,73 +1,89 @@
-#include <sys/time.h>
-#include <cassert>
+#include <assert.h>
 #include <iostream>
-#include <cuda_fp16.h>
 #include <cuda.h>
 #include <mma.h>
-#include <cuda_runtime_api.h>
+#include <cuda_fp16.h>
 
-using namespace nvcuda;
+#define N 32
+#define M 32
+#define K 32
 
+using namespace nvcuda;
 
-struct timeval tv0, tv1;
+__global__ void foo(half *a, half *b, float *c) {
+  int block_x = blockIdx.x / 2;
+  int block_y = blockIdx.x % 2;
 
-void begin_roi() {
-  gettimeofday(&tv0, nullptr);
-}
+  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
+  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
+  wmma::fragment<wmma::accumulator, 16, 16, 16, float, void> c_frag;
+  wmma::fill_fragment(c_frag, 0.0f);
 
-#define TV_TO_SEC(tv) (tv.tv_sec * 1000000 + tv.tv_usec)
+  for (int k = 0; k < M; k += 16) {
+    wmma::load_matrix_sync(a_frag, a + M * block_x + k, M);
+    wmma::load_matrix_sync(b_frag, b + K * k + block_y * 16, K);
+    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+  }
 
-void end_roi() {
-  gettimeofday(&tv1, nullptr);
-  std::cout << TV_TO_SEC(tv1) - TV_TO_SEC(tv0) << std::endl;
+  wmma::store_matrix_sync(c + K * block_x * 16 + block_y * 16, c_frag, K, wmma::mem_row_major);
 }
 
-extern "C" __global__ void default_function_kernel0( half* __restrict__ a, half* __restrict__ b, float* __restrict__ c) {
-
-  for (int x_outer_inner = 0; x_outer_inner < 4; ++x_outer_inner) {
-    for (int y_outer_inner = 0; y_outer_inner < 4; ++y_outer_inner) {
-
-      wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
-
-      wmma::fill_fragment(c_frag, 0.0f);
-
-      wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
-      wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
-
-
-      for (int rv_outer = 0; rv_outer < 256; ++rv_outer) {
-
-        half *ptr_a = &a[((((((int)blockIdx.x) * 262144) + (x_outer_inner * 65536)) + (rv_outer * 16)))];
-        wmma::load_matrix_sync(a_frag, ptr_a, 4096);
-        half *ptr_b = &b[((((((int)threadIdx.x) * 262144) + (y_outer_inner * 65536)) + (rv_outer * 16)))];
-        wmma::load_matrix_sync(b_frag, ptr_b, 4096);
-        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
-
-      }
-      __syncthreads();
-
-      float *ptr_c = &c[((((((((int)blockIdx.x) * 262144) + (x_outer_inner * 65536))) + (((int)threadIdx.x) * 64)) + (y_outer_inner * 16)))];
-      wmma::store_matrix_sync(ptr_c, c_frag, 4096, wmma::mem_row_major);
+half a[N * M], b[M * K];
+float c[N * K], ref[N * K];
 
+template<typename T>
+void print(int n, int m, const T* a) {
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < m; ++j) {
+      if (j) std::cout << " ";
+      std::cout << a[i * m + j];
     }
+    std::cout << std::endl;
   }
 }
 
-int main() {
-
-  half *a, *b;
-  float *c;
-
-  cudaMalloc(&a, 4096 * 4096 * (sizeof (half)));
-  cudaMalloc(&b, 4096 * 4096 * (sizeof (half)));
-  cudaMalloc(&c, 4096 * 4096 * (sizeof (float)));
-
-  begin_roi();
-  for (int i = 0; i < 10; ++i) {
-    default_function_kernel0<<<64, 64>>>(a, b, c);
+template<>
+void print(int n, int m, const half* a) {
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < m; ++j) {
+      if (j) std::cout << " ";
+      std::cout << __half2float(a[i * m + j]);
+    }
+    std::cout << std::endl;
   }
-  assert(cudaDeviceSynchronize() == cudaSuccess);
-  end_roi();
+}
 
+int main() {
+  cudaDeviceProp prop;
+  assert(cudaSuccess == cudaGetDeviceProperties(&prop, 0));
+  std::cout << "Warp size is: " << prop.warpSize << std::endl;
+
+  for (int i = 0; i < N * M; ++i)
+    a[i] = __float2half((float) rand() / RAND_MAX * 0.5);
+  for (int i = 0; i < M * K; ++i)
+    b[i] = __float2half((float) rand() / RAND_MAX * 0.5);
+  for (int i = 0; i < N * K; ++i)
+    c[i] = 0;
+  for (int i = 0; i < N; ++i)
+    for (int j = 0; j < K; ++j) {
+      ref[i * K + j] = 0.0;
+      for (int k = 0; k < M; ++k)
+        ref[i * K + j] += __half2float(a[i * M + k]) * __half2float(b[k * K + j]);
+    }
+  half *dev_a, *dev_b;
+  float *dev_c;
+  cudaMalloc(&dev_a, N * M * sizeof(half));
+  cudaMalloc(&dev_b, M * K * sizeof(half));
+  cudaMalloc(&dev_c, N * K * sizeof(float));
+  cudaMemcpy(dev_a, a, sizeof a, cudaMemcpyHostToDevice);
+  cudaMemcpy(dev_b, b, sizeof b, cudaMemcpyHostToDevice);
+  cudaMemcpy(dev_c, c, sizeof c, cudaMemcpyHostToDevice);
+  foo<<<4, 32>>>(dev_a, dev_b, dev_c);
+  cudaDeviceSynchronize();
+  cudaMemcpy(c, dev_c, sizeof c, cudaMemcpyDeviceToHost);
+  std::cout.precision(1);
+  std::cout << std::fixed;
+  //print(N, M, a);
+  print(N, K, c);
   return 0;
 }
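The new main() builds a host reference in ref but only prints the device result. A check along the following lines could be appended after the final cudaMemcpy back to host; this is a sketch, not part of the commit, the tolerance chosen to absorb fp16 rounding is an assumption, and it assumes fabsf is available (e.g. via <cmath>).

// Hypothetical verification snippet, placed after
// cudaMemcpy(c, dev_c, sizeof c, cudaMemcpyDeviceToHost);
bool ok = true;
for (int i = 0; i < N * K; ++i) {
  // absolute + relative tolerance for fp16 accumulation error
  if (fabsf(c[i] - ref[i]) > 1e-2f + 1e-2f * fabsf(ref[i])) {
    std::cout << "mismatch at " << i << ": got " << c[i]
              << ", expected " << ref[i] << std::endl;
    ok = false;
  }
}
std::cout << (ok ? "PASS" : "FAIL") << std::endl;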

tensorcore/tensorcore.py

+7 -54

@@ -1,8 +1,8 @@
 import tvm
 
-n = 4096
-k = 4096
-m = 4096
+n = 16
+k = 16
+m = 16
 
 a = tvm.placeholder((n, k), dtype='float16', name='a')
 b = tvm.placeholder((m, k), dtype='float16', name='b')
@@ -13,63 +13,16 @@
 
 sch = tvm.create_schedule(c.op)
 
-c_write = sch.cache_write(c, 'local')
-
 x, y = c.op.axis
 
 xo, xi = sch[c].split(x, 16)
-xoo, xoi = sch[c].split(xo, 4)
 yo, yi = sch[c].split(y, 16)
-yoo, yoi = sch[c].split(yo, 4)
 
 blcx = tvm.thread_axis('blockIdx.x')
 thrx = tvm.thread_axis('threadIdx.x')
 
-sch[c].bind(xoo, blcx)
-sch[c].bind(yoo, thrx)
-
-red_axis = c_write.op.reduce_axis[0]
-ro, ri = sch[c_write].split(red, 16)
-sch[c_write].reorder(ro, c_write.op.axis[0], c_write.op.axis[1], ri)
-
-#ax0, ax1 = a_shared.op.axis
-#ax1o, ax1i = sch[a_shared].split(ax1, 4)
-#fused = sch[a_shared].fuse(ax0, ax1o)
-#sch[a_shared].bind(fused, thrx)
-
-sch[c].reorder(xoo, yoo, xoi, yoi, xi, yi)
-
-#a_shared = sch.cache_read(a, 'shared', [c_write])
-#sch[a_shared].compute_at(sch[c_write], ro)
-
-sch[c_write].compute_at(sch[c], yoi)
-
-b_shared = sch.cache_read(b, 'shared', [c_write])
-sch[b_shared].compute_at(sch[c_write], ri)
-
-def toucher(op):
-    if isinstance(op, tvm.stmt.For):
-        print(op.loop_var)
-        print('a: ', tvm.arith.DomainTouched(op, a, True, True))
-        print('b: ', tvm.arith.DomainTouched(op, b, True, True))
-        print('c: ', tvm.arith.DomainTouched(op, c, True, True))
-        print('c.local: ', tvm.arith.DomainTouched(op, c_write, True, True))
-
-tvm.ir_pass.PostOrderVisit(tvm.build_module.form_body(sch), toucher)
-
-ir = tvm.lower(sch, [a, b, c], simple_mode=True)
-print(ir)
-
-module = tvm.build(sch, [a, b, c], target='cuda')
-print(module.imported_modules[0].get_source())
-
-module.imported_modules[0].save('gemm.cu')
-
-#import numpy as np
-#
-#nda = tvm.ndarray.array(np.random.randn(n, k).astype('float16'), tvm.gpu(0))
-#ndb = tvm.ndarray.array(np.random.randn(m, k).astype('float16'), tvm.gpu(0))
-#ndc = tvm.ndarray.array(np.zeros((n, m), dtype='float32'), tvm.gpu(0))
+yio, yii = sch[c].split(yi, 8)
+sch[c].reorder(xo, yo, xi, yio, yii, red)
+xy = sch[c].fuse(xi, yio)
 
-#timer = module.time_evaluator(module.entry_name, tvm.gpu(0), number=10)
-#print(timer(nda, ndb, ndc).mean)
+print(tvm.lower(sch, [a, b, c], simple_mode=True))
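For reference, the loop nest this trimmed-down schedule asks tvm.lower to emit can be pictured roughly as the host-side sketch below. It assumes c is the usual reduction c[x, y] = sum_r a[x, r] * b[y, r] over the k axis, which the placeholder shapes and the removed numpy reference suggest but this hunk does not show; nothing is bound to blcx/thrx yet, and plain float stands in for float16.

// Hypothetical C++ picture of the lowered loop nest; extents follow from
// n = k = m = 16 and the split/reorder/fuse calls above.
void gemm_sketch(const float *a, const float *b, float *c) {
  for (int i = 0; i < 16 * 16; ++i) c[i] = 0.0f;  // reduction init
  for (int xo = 0; xo < 1; ++xo)                  // split(x, 16): outer extent 1
    for (int yo = 0; yo < 1; ++yo)                // split(y, 16): outer extent 1
      for (int xy = 0; xy < 32; ++xy) {           // xy = fuse(xi, yio): 16 * 2
        int xi = xy / 2, yio = xy % 2;
        for (int yii = 0; yii < 8; ++yii)         // split(yi, 8): inner extent 8
          for (int r = 0; r < 16; ++r) {          // reduce axis 'red'
            int x = xo * 16 + xi;
            int y = yo * 16 + yio * 8 + yii;
            c[x * 16 + y] += a[x * 16 + r] * b[y * 16 + r];
          }
      }
}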

vnni/Makefile

+2 -2

@@ -6,8 +6,8 @@ poc.exe gemm.exe: %.exe: %.cu
 
 mkldnn_conv.out mkldnn_gemm.out: %.out: %.cc
 	clang++ $^ -std=c++11 -march=cascadelake -o $@ -O3 \
-		-I../../mkl-dnn/include -I../../mkl-dnn/build/include \
-		-L../../mkl-dnn/build/src -lmkldnn -lm -lpthread -lz
+		-I$(MKLDNN)/include -I../../mkl-dnn/build/include \
+		-L$(MKLDNN)/build/src -lmkldnn -lm -lpthread -lz
 
 clean:
 	rm -f *.out *.ll main poc *.exe *.s *.o

vnni/mkldnn_gemm.cc

+7 -3

@@ -22,17 +22,21 @@ int main() {
 
   {
     begin_roi();
-    for (int i = 0; i < 100; ++i) {
+    for (int i = 0; i < 10; ++i) {
       mkldnn_status_t status = mkldnn_gemm_s8s8s32('N', 'T', 'F', n, m, k, 1.0,
                                                    a, k, 0, b, k, 0, 0.0,
                                                    c, m, &co);
       assert(status == mkldnn_success);
     }
     float res = end_roi();
-    float gvnnis = (float) n * m * k / 64.f / res / 10.0;
+    float gvnnis = ((float) n * m * k / 64.f * 10.0 / res) / 1000.;
     printf("Execution time: %.5f\n", res / 100. / 1000000.);
-    printf("%.2f GVNNI/s\n", gvnnis / 8 * 7);
+    printf("%.2f GVNNI/us\n", gvnnis);
   }
 
+  delete[] a;
+  delete[] b;
+  delete[] c;
+
   return 0;
 }
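For context, the routine timed above is MKL-DNN's int8 GEMM. Below is a minimal standalone sketch of the same mkldnn_gemm_s8s8s32 call; the argument list mirrors the timed loop in the diff, while the matrix sizes and the int8_t/int32_t element types are assumptions, not taken from this file.

// Hypothetical standalone harness around the call timed above:
// C = alpha * A * B^T + beta * C with s8 inputs, s32 output, and a
// fixed ('F') output offset of zero.
#include <mkldnn.h>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int n = 64, m = 64, k = 64;              // sizes are an assumption
  std::vector<int8_t> a(n * k, 1), b(m * k, 1);  // int8 inputs
  std::vector<int32_t> c(n * m, 0);              // int32 output
  int32_t co = 0;                                // single offset, since offsetc == 'F'

  mkldnn_status_t status = mkldnn_gemm_s8s8s32('N', 'T', 'F', n, m, k, 1.0,
                                               a.data(), k, 0,
                                               b.data(), k, 0, 0.0,
                                               c.data(), m, &co);
  assert(status == mkldnn_success);
  return 0;
}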
