Skip to content

Commit 547955c

Browse files
committed
Update readme, simplify C++
1 parent 2f7587d commit 547955c

File tree

6 files changed

+72
-69
lines changed

6 files changed

+72
-69
lines changed

cutde/coordinators.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,7 @@ def call_clu_free(obs_pts, tris, slips, nu, fnc):
195195
gpu_results = backend.zeros(n_obs * vec_dim, float_type)
196196

197197
n_obs_blocks = int(np.ceil(n_obs / block_size))
198-
gpu_config = dict(
199-
free_block_size=block_size, float_type=backend.np_to_c_type(float_type)
200-
)
198+
gpu_config = dict(float_type=backend.np_to_c_type(float_type))
201199
module = backend.load_module("free.cu", tmpl_args=gpu_config, tmpl_dir=source_dir)
202200

203201
# Split up the sources into chunks so that we don't completely overwhelm a

cutde/cpp_backend.cpp

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,32 +19,21 @@ struct XYZ {
1919
SIZE_T z;
2020
};
2121

22-
thread_local XYZ threadIdx;
2322
thread_local XYZ blockIdx;
24-
XYZ blockDim;
2523
XYZ gridDim;
2624

27-
WITHIN_KERNEL SIZE_T get_local_id(unsigned int dim)
28-
{
29-
if(dim == 0) return threadIdx.x;
30-
if(dim == 1) return threadIdx.y;
31-
if(dim == 2) return threadIdx.z;
32-
return 0;
33-
}
25+
WITHIN_KERNEL SIZE_T get_local_id(unsigned int dim) { return 0; }
26+
3427
WITHIN_KERNEL SIZE_T get_group_id(unsigned int dim)
3528
{
3629
if(dim == 0) return blockIdx.x;
3730
if(dim == 1) return blockIdx.y;
3831
if(dim == 2) return blockIdx.z;
3932
return 0;
4033
}
41-
WITHIN_KERNEL SIZE_T get_local_size(unsigned int dim)
42-
{
43-
if(dim == 0) return blockDim.x;
44-
if(dim == 1) return blockDim.y;
45-
if(dim == 2) return blockDim.z;
46-
return 1;
47-
}
34+
35+
WITHIN_KERNEL SIZE_T get_local_size(unsigned int dim) { return 1; }
36+
4837
WITHIN_KERNEL SIZE_T get_num_groups(unsigned int dim)
4938
{
5039
if(dim == 0) return gridDim.x;
@@ -54,11 +43,11 @@ WITHIN_KERNEL SIZE_T get_num_groups(unsigned int dim)
5443
}
5544
WITHIN_KERNEL SIZE_T get_global_size(unsigned int dim)
5645
{
57-
return get_num_groups(dim) * get_local_size(dim);
46+
return get_num_groups(dim);
5847
}
5948
WITHIN_KERNEL SIZE_T get_global_id(unsigned int dim)
6049
{
61-
return get_local_id(dim) + get_group_id(dim) * get_local_size(dim);
50+
return get_group_id(dim);
6251
}
6352

6453
#include <pybind11/pybind11.h>
@@ -101,17 +90,10 @@ decltype(auto) wrapper(R(*fn)(Args...))
10190
std::tuple<SIZE_T,SIZE_T,SIZE_T> block)
10291
{
10392
gridDim = {std::get<0>(grid), std::get<1>(grid), std::get<2>(grid)};
104-
blockDim = {std::get<0>(block), std::get<1>(block), std::get<2>(block)};
10593
blockIdx = {0,0,0};
106-
threadIdx = {0,0,0};
10794

10895
SIZE_T Ngrid = gridDim.x * gridDim.y * gridDim.z;
10996

110-
// block must be (1,1,1)
111-
assert(std::get<0>(block) == 0);
112-
assert(std::get<1>(block) == 0);
113-
assert(std::get<2>(block) == 0);
114-
11597
auto ptr_args = std::make_tuple(conv_arg(args)...);
11698

11799
#pragma omp parallel for

cutde/free.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ void free_${name}(GLOBAL_MEM Real* results,
1313
{
1414
int i = get_global_id(0);
1515
int group_id = get_local_id(0);
16+
int block_size = get_local_size(0);
1617

1718
%for d_obs in range(vec_dim):
1819
Real sum${d_obs} = 0.0;
@@ -27,14 +28,14 @@ void free_${name}(GLOBAL_MEM Real* results,
2728
}
2829

2930
% for d1 in range(3):
30-
LOCAL_MEM Real3 sh_tri${d1}[${free_block_size}];
31+
LOCAL_MEM Real3 sh_tri${d1}[256];
3132
% endfor
32-
LOCAL_MEM Real3 sh_slips[${free_block_size}];
33+
LOCAL_MEM Real3 sh_slips[256];
3334

3435
// NOTE: The blocking scheme set up here seems to be irrelevant because the
3536
// runtime is totally dominated by the floating point operations inside the
3637
// TDE evaluation.
37-
for (int block_start = src_start; block_start < src_end; block_start += ${free_block_size}) {
38+
for (int block_start = src_start; block_start < src_end; block_start += block_size) {
3839
int j = block_start + group_id;
3940
if (j < src_end) {
4041
% for d1 in range(3):
@@ -48,7 +49,7 @@ void free_${name}(GLOBAL_MEM Real* results,
4849
${common.LOCAL_BARRIER()}
4950

5051
if (i < n_obs) {
51-
int block_end = min(src_end, block_start + ${free_block_size});
52+
int block_end = min(src_end, block_start + block_size);
5253
int block_length = block_end - block_start;
5354
for (int block_idx = 0; block_idx < block_length; block_idx++) {
5455
% for d1 in range(3):

0 commit comments

Comments
 (0)