Commit 5e59143
update
1 parent 7218517 commit 5e59143

10 files changed: +504 -148 lines changed

examples/atom/ldgmem_ldsmem_v0.cu

+96
@@ -0,0 +1,96 @@
#include "common.h"
// copy async: one bulk H2D copy, then a global -> shared -> global kernel

// nvcc -arch=sm_90a -std=c++17 -I ../../include/ -lcuda ldgmem_ldsmem_v0.cu -o test

const int SM_LODA_BYTES = 128/8;

template <typename DType, int BLOCKM, int BLOCKN, int NUM_THREADS>
__global__ void naive_matrix_ldsm(DType* source, int M, int N, DType* dummy_out) {
  __shared__ DType smem[BLOCKM*BLOCKN];
  // each thread moves one 128-bit (int4) vector per row repeat
  const int VEC_LEN = SM_LODA_BYTES / sizeof(DType);
  const int VEC_REPEAT = BLOCKN / VEC_LEN;
  const int THREAD_N = VEC_REPEAT;
  const int THREAD_M = NUM_THREADS / THREAD_N;
  const int ROW_REPEAT = BLOCKM / THREAD_M;
  static_assert(BLOCKN % VEC_LEN == 0);
  static_assert(NUM_THREADS % THREAD_N == 0);
  static_assert(ROW_REPEAT * THREAD_M == BLOCKM);

  int mo = blockIdx.x * BLOCKM;
  int mi = threadIdx.x / THREAD_N;
  int ni = threadIdx.x % THREAD_N;
  int4* ld_source = reinterpret_cast<int4*>(source);
  int4* ld_smem = reinterpret_cast<int4*>(smem);
  for (int no = 0; no < N; no += BLOCKN) {
    // global -> shared, vectorized as int4
    for (int row_repeat = 0; row_repeat < ROW_REPEAT; ++row_repeat) {
      int m = mo + row_repeat * THREAD_M + mi;
      int n = no + ni * VEC_LEN;
      int idx = m * N + n;
      int sm = row_repeat * THREAD_M + mi;
      int sn = ni * VEC_LEN;
      int sm_idx = sm * BLOCKN + sn;
      ld_smem[sm_idx / VEC_LEN] = ld_source[idx / VEC_LEN];
    }
    __syncthreads();
    // shared -> global, with a dummy +1 so the traffic is not optimized away
    for (int row_repeat = 0; row_repeat < ROW_REPEAT; ++row_repeat) {
      int m = mo + row_repeat * THREAD_M + mi;
      int n = no + ni * VEC_LEN;
      int idx = m * N + n;
      int sm = row_repeat * THREAD_M + mi;
      int sn = ni * VEC_LEN;
      int sm_idx = sm * BLOCKN + sn;
      for (int i = 0; i < VEC_LEN; ++i) {
        dummy_out[idx + i] = smem[sm_idx + i] + DType(1);
      }
    }
  }
}


template<typename DType>
void cpu_dummy(DType* source, DType* dummy_out, int M, int N) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      dummy_out[m * N + n] = (DType)((float)source[m * N + n] + (float)DType(1));
    }
  }
}


int main(int argc, char** argv) {
  const int M = 1024;
  const int N = 1024;
  using DType = half;
  const int BLOCKM = 128;
  const int BLOCKN = 128;
  const int NUM_THREADS = 128;
  std::vector<int> shape{M, N};
  auto A = alloc_cpu_tensor<DType>(shape);
  random_fill(A, shape);
  auto B = alloc_cpu_tensor<DType>(shape);
  auto golden = alloc_cpu_tensor<DType>(shape);

  GPUTimer gpu_timer;

  auto dA = alloc_gpu_tensor<DType>(shape);
  auto dB = alloc_gpu_tensor<DType>(shape);
  gpu_timer.sync_all();
  gpu_timer.tick();
  copy_to_gpu_async(A, dA, shape);
  dim3 block(NUM_THREADS);
  dim3 grid(ceil_div(M, BLOCKM));
  naive_matrix_ldsm<DType, BLOCKM, BLOCKN, NUM_THREADS><<<grid, block>>>(dA, M, N, dB);
  copy_to_cpu_async(B, dB, shape);
  gpu_timer.tick();
  gpu_timer.sync_all();
  std::cout << "GPU naive done! Use " << gpu_timer.report_last_ms() << " ms.\n";

  std::cout << "Calculating golden...\n";
  cpu_dummy(A, golden, M, N);
  assert_allclose(B, golden, shape, 1e-5, /*dump=*/false);
  std::cout << "Correct!\n";

  return 0;
}
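For reference, the tiling constants that naive_matrix_ldsm derives at compile time resolve as follows for the half-precision launch in main (BLOCKM = BLOCKN = NUM_THREADS = 128). The snippet below is a standalone host-side sketch added here for illustration; it is not part of the commit:

// Standalone sketch (illustration only): recompute the per-thread tiling that
// naive_matrix_ldsm uses for DType = half, BLOCKM = BLOCKN = 128, NUM_THREADS = 128.
#include <cstdio>

int main() {
  const int LOAD_BYTES = 128 / 8;               // one 128-bit (int4) vector per load
  const int SIZEOF_HALF = 2;
  const int BLOCKM = 128, BLOCKN = 128, NUM_THREADS = 128;

  const int VEC_LEN = LOAD_BYTES / SIZEOF_HALF; // 8 half elements per int4
  const int THREAD_N = BLOCKN / VEC_LEN;        // 16 threads span one tile row
  const int THREAD_M = NUM_THREADS / THREAD_N;  // 8 tile rows covered per pass
  const int ROW_REPEAT = BLOCKM / THREAD_M;     // 16 passes cover all 128 rows

  std::printf("VEC_LEN=%d THREAD_N=%d THREAD_M=%d ROW_REPEAT=%d\n",
              VEC_LEN, THREAD_N, THREAD_M, ROW_REPEAT);
  // Each thread therefore moves ROW_REPEAT = 16 int4 vectors (256 B) of the
  // 128 x 128 x 2 B = 32 KiB tile per BLOCKN step of the no loop.
  return 0;
}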

examples/atom/ldgmem_ldsmem_v1.cu

+108
@@ -0,0 +1,108 @@
#include "common.h"
// split copy async: split the H2D copy into chunks and launch one kernel per chunk

// nvcc -arch=sm_90a -std=c++17 -I ../../include/ -lcuda ldgmem_ldsmem_v1.cu -o test

const int SM_LODA_BYTES = 128/8;

template <typename DType, int BLOCKM, int BLOCKN, int NUM_THREADS>
__global__ void split_matrix_ldsm(DType* source, int M, int N, DType* dummy_out, int split, int curr_split) {
  __shared__ DType smem[BLOCKM*BLOCKN];
  const int VEC_LEN = SM_LODA_BYTES / sizeof(DType);
  const int VEC_REPEAT = BLOCKN / VEC_LEN;
  const int THREAD_N = VEC_REPEAT;
  const int THREAD_M = NUM_THREADS / THREAD_N;
  const int ROW_REPEAT = BLOCKM / THREAD_M;
  static_assert(BLOCKN % VEC_LEN == 0);
  static_assert(NUM_THREADS % THREAD_N == 0);
  static_assert(ROW_REPEAT * THREAD_M == BLOCKM);

  // each launch handles a chunk of M/split rows; offset the output to its chunk
  dummy_out += M / split * curr_split * N;

  int mo = blockIdx.x * BLOCKM;
  int mi = threadIdx.x / THREAD_N;
  int ni = threadIdx.x % THREAD_N;
  int4* ld_source = reinterpret_cast<int4*>(source);
  int4* ld_smem = reinterpret_cast<int4*>(smem);
  for (int no = 0; no < N; no += BLOCKN) {
    for (int row_repeat = 0; row_repeat < ROW_REPEAT; ++row_repeat) {
      int m = mo + row_repeat * THREAD_M + mi;
      int n = no + ni * VEC_LEN;
      int idx = m * N + n;
      int sm = row_repeat * THREAD_M + mi;
      int sn = ni * VEC_LEN;
      int sm_idx = sm * BLOCKN + sn;
      ld_smem[sm_idx / VEC_LEN] = ld_source[idx / VEC_LEN];
    }
    __syncthreads();
    for (int row_repeat = 0; row_repeat < ROW_REPEAT; ++row_repeat) {
      int m = mo + row_repeat * THREAD_M + mi;
      int n = no + ni * VEC_LEN;
      int idx = m * N + n;
      int sm = row_repeat * THREAD_M + mi;
      int sn = ni * VEC_LEN;
      int sm_idx = sm * BLOCKN + sn;
      for (int i = 0; i < VEC_LEN; ++i) {
        dummy_out[idx + i] = smem[sm_idx + i] + DType(1);
      }
    }
  }
}


template<typename DType>
void cpu_dummy(DType* source, DType* dummy_out, int M, int N) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      dummy_out[m * N + n] = (DType)((float)source[m * N + n] + (float)DType(1));
    }
  }
}


int main(int argc, char** argv) {
  const int M = 1024;
  const int N = 1024;
  int split = 4;
  using DType = half;
  const int BLOCKM = 128;
  const int BLOCKN = 128;
  const int NUM_THREADS = 128;
  std::vector<int> shape{M, N};
  std::vector<int> epoch_shape{M/split, N};
  auto A = alloc_cpu_tensor<DType>(shape);
  random_fill(A, shape);
  // constant_fill(A, shape, DType(1));
  auto B = alloc_cpu_tensor<DType>(shape);
  auto golden = alloc_cpu_tensor<DType>(shape);

  GPUTimer gpu_timer;

  std::vector<DType*> dAs;
  for (int i = 0; i < split; ++i) {
    dAs.push_back(alloc_gpu_tensor<DType>(epoch_shape));
  }
  auto dB = alloc_gpu_tensor<DType>(shape);

  dim3 block(NUM_THREADS);
  dim3 grid(ceil_div(M/split, BLOCKM));
  gpu_timer.sync_all();
  gpu_timer.tick();
  for (int i = 0; i < split; ++i) {
    // copy chunk i, then launch on it; everything here uses the default stream
    copy_to_gpu_async(A + M/split * i * N, dAs[i], epoch_shape);
    split_matrix_ldsm<DType, BLOCKM, BLOCKN, NUM_THREADS><<<grid, block>>>(dAs[i], M, N, dB, split, i);
  }
  gpu_timer.tick();
  gpu_timer.sync_all();
  std::cout << "GPU split done! Use " << gpu_timer.report_last_ms() << " ms.\n";
  copy_to_cpu_async(B, dB, shape);

  std::cout << "Calculating golden...\n";
  cpu_dummy(A, golden, M, N);
  assert_allclose(B, golden, shape, 1e-5, /*dump=*/false);
  std::cout << "Correct!\n";

  return 0;
}
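In main above, each chunk's copy_to_gpu_async and kernel launch go to the default stream, so they still execute strictly back-to-back; the split mainly shortens the first copy the first kernel has to wait for. Below is a minimal sketch of the next step, giving each split its own stream so chunk i+1's H2D copy can overlap chunk i's kernel. This is not part of the commit, and it assumes A sits in pinned host memory (cudaMallocHost / cudaHostRegister); with pageable memory, cudaMemcpyAsync behaves synchronously and nothing overlaps.

// Sketch only (not in the commit): one stream per split to overlap copy and compute.
std::vector<cudaStream_t> streams(split);
for (int i = 0; i < split; ++i) {
  CUDA_CHECK(cudaStreamCreate(&streams[i]));
}
gpu_timer.sync_all();
gpu_timer.tick();
for (int i = 0; i < split; ++i) {
  // chunk i's kernel is ordered after chunk i's copy by stream i,
  // while chunk i+1's copy can start concurrently on its own stream
  copy_to_gpu_async(A + M / split * i * N, dAs[i], epoch_shape, streams[i]);
  split_matrix_ldsm<DType, BLOCKM, BLOCKN, NUM_THREADS>
      <<<grid, block, 0, streams[i]>>>(dAs[i], M, N, dB, split, i);
}
gpu_timer.tick();
gpu_timer.sync_all();
for (int i = 0; i < split; ++i) {
  CUDA_CHECK(cudaStreamDestroy(streams[i]));
}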

include/common.h

+29
@@ -268,6 +268,15 @@ void random_fill(DType* tensor, std::vector<int> shape) {
   }
 }
 
+template <class DType>
+void constant_fill(DType* tensor, std::vector<int> shape, DType value) {
+  int length =
+      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
+  for (int i = 0; i < length; ++i) {
+    tensor[i] = value;
+  }
+}
+
 template <class DType>
 DType* alloc_gpu_tensor(std::vector<int> shape) {
   DType* dt;
@@ -297,6 +306,16 @@ void copy_to_gpu(DType* hptr, DType* dptr, std::vector<int> shape) {
       cudaMemcpyHostToDevice));
 }
 
+template <class DType>
+void copy_to_gpu_async(DType* hptr, DType* dptr, std::vector<int> shape,
+                       cudaStream_t stream = 0) {
+  CUDA_CHECK(cudaMemcpyAsync(
+      dptr, hptr,
+      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
+          sizeof(DType),
+      cudaMemcpyHostToDevice, stream));
+}
+
 template <class DType>
 void copy_to_cpu(DType* hptr, DType* dptr, std::vector<int> shape) {
   CUDA_CHECK(cudaMemcpy(
@@ -306,6 +325,16 @@ void copy_to_cpu(DType* hptr, DType* dptr, std::vector<int> shape) {
       cudaMemcpyDeviceToHost));
 }
 
+template <class DType>
+void copy_to_cpu_async(DType* hptr, DType* dptr, std::vector<int> shape,
+                       cudaStream_t stream = 0) {
+  CUDA_CHECK(cudaMemcpyAsync(
+      hptr, dptr,
+      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
+          sizeof(DType),
+      cudaMemcpyDeviceToHost, stream));
+}
+
 template <class DType>
 void assert_allclose(DType* res_ptr, DType* golden_ptr, std::vector<int> shape,
                      float rtol = 1e-5, bool dump = false) {
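Both new helpers default to the legacy stream (stream = 0), which is how the two examples above call them without passing a stream. A brief usage sketch with an explicit stream follows; the buffer names are illustrative, not from the commit, and the host must synchronize before reading a buffer filled by copy_to_cpu_async:

// Illustrative usage of the new async copy helpers (not from the commit).
cudaStream_t s;
CUDA_CHECK(cudaStreamCreate(&s));

std::vector<int> shape{1024, 1024};
auto h_buf = alloc_cpu_tensor<half>(shape);   // pageable unless allocated pinned
auto d_buf = alloc_gpu_tensor<half>(shape);

copy_to_gpu_async(h_buf, d_buf, shape, s);    // H2D enqueued on s
// ... enqueue kernels on s that consume d_buf ...
copy_to_cpu_async(h_buf, d_buf, shape, s);    // D2H enqueued on s
CUDA_CHECK(cudaStreamSynchronize(s));         // h_buf is valid only after this
CUDA_CHECK(cudaStreamDestroy(s));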

util/cutlass/test_tile_scheduler.cu

+91
@@ -0,0 +1,91 @@
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"

#include "common.h"

using namespace cutlass;
using namespace cutlass::gemm::kernel::detail;
using namespace cute;

using Scheduler = PersistentTileSchedulerSm90;

/// nvcc -arch=sm_90a -I ../../include -I /home/jshao/zhengsz/cutlass/include -lcuda -std=c++17 test_tile_scheduler.cu -o test

struct KernelSharedStorage {
};

struct KernelParams {
  int M;
  int N;
  int K;
  Scheduler::Params schedule_params;
  int* idx;
};

__global__ void test_kernel(KernelParams params) {
  Scheduler scheduler(params.schedule_params);
  auto tileinfo = scheduler.get_current_work();
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    int is_n = params.schedule_params.raster_order_ == Scheduler::RasterOrder::AlongN;
    printf("log swizzle %d is n %d\n", params.schedule_params.log_swizzle_size_, is_n);
  }
  if (threadIdx.x == 0) {
    printf("block %d maps to linear m %d n %d\n", blockIdx.x, tileinfo.M_idx, tileinfo.N_idx);
  }
}

int main() {
  const int M = 4096;
  const int N = 4096;
  const int K = 4096;
  dim3 grid(SM_NUMBER, 1, 1);
  dim3 block(WARP_GROUP_SIZE * WG_NUMBER, 1, 1);
  const int CLUSTER_M = 2;
  const int CLUSTER_N = 1;
  dim3 cluster(CLUSTER_M, CLUSTER_N, 1);
  int smemSizeBytes = sizeof(KernelSharedStorage);
  void const *kernel =
      (void const *)test_kernel;

  auto idx = alloc_cpu_tensor<int>({(int)block.x});
  auto g_idx = alloc_gpu_tensor<int>({(int)block.x});

  using ShapeMNKL = Shape<int, int, int, int>;
  ShapeMNKL shape{M, N, K, 1};
  using TileShape = Shape<_128, _128, _64>;
  TileShape tile_shape{};
  using ClusterShape = Shape<_2, _1, _1>;
  ClusterShape cluster_shape{};
  KernelHardwareInfo info{};
  Scheduler::Arguments args{};
  Scheduler::Params schedule_params = Scheduler::to_underlying_arguments(shape, tile_shape, cluster_shape, info, args);

  KernelParams params{M, N, K, schedule_params, g_idx};
  void *kernel_params[] = {&params};
  cudaLaunchConfig_t launch_config;
  launch_config.gridDim = {grid.x, grid.y, grid.z};
  launch_config.blockDim = {block.x, block.y, block.z};
  launch_config.dynamicSmemBytes = smemSizeBytes;
  launch_config.stream = nullptr;

  // launch with a (2, 1, 1) thread-block cluster via the extended launch API
  cudaLaunchAttribute launch_attribute[1];
  launch_attribute[0].id = cudaLaunchAttributeClusterDimension;
  launch_attribute[0].val.clusterDim.x = cluster.x;
  launch_attribute[0].val.clusterDim.y = cluster.y;
  launch_attribute[0].val.clusterDim.z = cluster.z;

  launch_config.attrs = launch_attribute;
  launch_config.numAttrs = 1;

  cudaError_t status =
      cudaLaunchKernelExC(&launch_config, kernel, kernel_params);
  cudaError_t launch_result = cudaGetLastError();
  CUDA_CHECK(status);
  CUDA_CHECK(launch_result);

  copy_to_cpu(idx, g_idx, {(int)block.x});

  return 0;
}
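The kernel prints the scheduler's raster order, its log swizzle size, and the (M_idx, N_idx) tile each persistent block picks up first. As a rough illustration of what such a mapping looks like, here is a host-side sketch of a generic swizzled along-N raster; the exact arithmetic inside PersistentTileSchedulerSm90 differs, so this only conveys the idea, not CUTLASS's implementation:

// Illustration only: a generic swizzled rasterization, NOT CUTLASS's exact mapping.
// With swizzle size S = 1 << log_swizzle, linear indices walk S tile columns at a
// time before moving down, keeping neighbouring CTAs on nearby output tiles.
#include <cstdio>

void linear_to_tile(int linear, int tiles_m, int log_swizzle, int* m_idx, int* n_idx) {
  const int S = 1 << log_swizzle;         // tile columns per swizzle group
  const int group_size = tiles_m * S;     // tiles in one column group
  const int group = linear / group_size;
  const int in_group = linear % group_size;
  *m_idx = in_group / S;                  // row advances fastest within a group
  *n_idx = group * S + in_group % S;
}

int main() {
  const int tiles_m = 4096 / 128;         // matches M = 4096 with 128x128 tiles
  for (int i = 0; i < 8; ++i) {
    int m, n;
    linear_to_tile(i, tiles_m, /*log_swizzle=*/1, &m, &n);
    std::printf("linear %d -> tile (m=%d, n=%d)\n", i, m, n);
  }
  return 0;
}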
