Skip to content

Commit 4c91491

Browse files
committed Oct 20, 2022
[add] tree_filter kernel support fp16
1 parent e1c0ea9 commit 4c91491

File tree

3 files changed

+118
-107
lines changed

3 files changed

+118
-107
lines changed
 

‎furnace/kernels/lib_tree_filter/src/bfs/bfs.cu

+7-7
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ bfs_forward(
111111
auto adj_vec_len_tensor = at::zeros({batch_size, vertex_count}, options);
112112
auto parent_index_tensor = at::zeros({batch_size, vertex_count}, options);
113113

114-
int * edge_index = edge_index_tensor.contiguous().data<int>();
115-
int * sorted_index = sorted_index_tensor.contiguous().data<int>();
116-
int * sorted_parent = sorted_parent_tensor.contiguous().data<int>();
117-
int * sorted_child = sorted_child_tensor.contiguous().data<int>();
118-
int * adj_vec = adj_vec_tensor.contiguous().data<int>();
119-
int * adj_vec_len = adj_vec_len_tensor.contiguous().data<int>();
120-
int * parent_index = parent_index_tensor.contiguous().data<int>();
114+
int * edge_index = edge_index_tensor.contiguous().data_ptr<int>();
115+
int * sorted_index = sorted_index_tensor.contiguous().data_ptr<int>();
116+
int * sorted_parent = sorted_parent_tensor.contiguous().data_ptr<int>();
117+
int * sorted_child = sorted_child_tensor.contiguous().data_ptr<int>();
118+
int * adj_vec = adj_vec_tensor.contiguous().data_ptr<int>();
119+
int * adj_vec_len = adj_vec_len_tensor.contiguous().data_ptr<int>();
120+
int * parent_index = parent_index_tensor.contiguous().data_ptr<int>();
121121

122122
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
123123

‎furnace/kernels/lib_tree_filter/src/mst/mst.cu

+4-4
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ at::Tensor mst_forward(
8989
unsigned edge_count = edge_index_tensor.size(1);
9090

9191
auto edge_index_cpu = edge_index_tensor.cpu();
92-
auto edge_weight_cpu = edge_weight_tensor.cpu();
92+
auto edge_weight_cpu = edge_weight_tensor.cpu().to(at::kFloat);
9393
auto edge_out_cpu = at::empty({batch_size, vertex_count - 1, 2}, edge_index_cpu.options());
9494

95-
int * edge_out = edge_out_cpu.contiguous().data<int>();
96-
int * edge_index = edge_index_cpu.contiguous().data<int>();
97-
float * edge_weight = edge_weight_cpu.contiguous().data<float>();
95+
int * edge_out = edge_out_cpu.contiguous().data_ptr<int>();
96+
int * edge_index = edge_index_cpu.contiguous().data_ptr<int>();
97+
float * edge_weight = edge_weight_cpu.contiguous().data_ptr<float>();
9898

9999
// Loop for batch
100100
std::thread pids[batch_size];

‎furnace/kernels/lib_tree_filter/src/refine/refine.cu

+107-96
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
#define CUDA_NUM_THREADS 64
1515
#define GET_CUDA_CHANNEL(N) ceil(512.0f / N)
1616

17+
template <typename scalar_t>
1718
__global__ void root_leaf_prop_kernel(
18-
float * in_data,
19-
float * out_data,
20-
float * weight,
19+
scalar_t * in_data,
20+
scalar_t * out_data,
21+
scalar_t * weight,
2122
int * sorted_index,
2223
int * sorted_parent_index,
2324
int batch_size,
@@ -53,7 +54,7 @@ __global__ void root_leaf_prop_kernel(
5354
int par_pos = sorted_index[par];
5455
for (int k = channel_idx * vertex_count; k < channel_size * vertex_count;
5556
k += channel_step * vertex_count){
56-
float edge_weight = weight[i];
57+
scalar_t edge_weight = weight[i];
5758
out_data[cur_pos + k] = in_data[i + k] * (1 - edge_weight * edge_weight) +
5859
out_data[par_pos + k] * edge_weight;
5960
__threadfence_block();
@@ -65,10 +66,11 @@ __global__ void root_leaf_prop_kernel(
6566
}
6667
}
6768

69+
template <typename scalar_t>
6870
__global__ void leaf_root_aggr_kernel(
69-
float * in_data,
70-
float * out_data,
71-
float * weight,
71+
scalar_t * in_data,
72+
scalar_t * out_data,
73+
scalar_t * weight,
7274
int * sorted_index,
7375
int * sorted_child_index,
7476
int batch_size,
@@ -113,7 +115,7 @@ __global__ void leaf_root_aggr_kernel(
113115
int cur_pos = sorted_index[i];
114116
for (int k = channel_idx * vertex_count; k < channel_size * vertex_count;
115117
k += channel_step * vertex_count){
116-
float aggr_sum;
118+
scalar_t aggr_sum;
117119
if (in_data != NULL)
118120
aggr_sum = in_data[cur_pos + k];
119121
else
@@ -131,13 +133,14 @@ __global__ void leaf_root_aggr_kernel(
131133
}
132134
}
133135

136+
template <typename scalar_t>
134137
__global__ void root_leaf_grad_kernel(
135-
float * in_data,
136-
float * in_grad,
137-
float * out_data,
138-
float * out_grad,
139-
float * weight,
140-
float * grad,
138+
scalar_t * in_data,
139+
scalar_t * in_grad,
140+
scalar_t * out_data,
141+
scalar_t * out_grad,
142+
scalar_t * weight,
143+
scalar_t * grad,
141144
int * sorted_index,
142145
int * sorted_parent_index,
143146
int batch_size,
@@ -172,14 +175,14 @@ __global__ void root_leaf_grad_kernel(
172175
int par_thread = par % thread_count;
173176
if ((cur == 0) || (node_per_thread[par_thread] >= par)){
174177
for (int k = channel_idx; k < channel_size; k += channel_step){
175-
float edge_weight = weight[i];
178+
scalar_t edge_weight = weight[i];
176179
int data_offset = (k % data_channel_size) * vertex_count;
177180
int grad_offset = (k % grad_channel_size) * vertex_count;
178181
int out_offset = k * vertex_count;
179182

180183
if (cur > 0){
181-
float left = in_grad[cur + grad_offset] * (out_data[par_pos + data_offset] - edge_weight * in_data[cur + data_offset]);
182-
float right = in_data[cur + data_offset] * (out_grad[par + grad_offset] - edge_weight * in_grad[cur + grad_offset]);
184+
scalar_t left = in_grad[cur + grad_offset] * (out_data[par_pos + data_offset] - edge_weight * in_data[cur + data_offset]);
185+
scalar_t right = in_data[cur + data_offset] * (out_grad[par + grad_offset] - edge_weight * in_grad[cur + grad_offset]);
183186

184187
grad[cur + out_offset] = left + right;
185188
out_grad[cur + grad_offset] = in_grad[cur + grad_offset] * (1 - edge_weight * edge_weight) +
@@ -218,31 +221,34 @@ refine_forward(
218221

219222
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
220223

221-
float * feature_in = feature_in_tensor.contiguous().data<float>();
222-
float * edge_weight = edge_weight_tensor.contiguous().data<float>();
223-
int * sorted_index = sorted_index_tensor.contiguous().data<int>();
224-
int * sorted_parent_index = sorted_parent_tensor.contiguous().data<int>();
225-
int * sorted_child_index = sorted_child_tensor.contiguous().data<int>();
226-
float * feature_aggr = feature_aggr_tensor.contiguous().data<float>();
227-
float * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data<float>();
228-
float * weight_sum = weight_sum_tensor.contiguous().data<float>();
229-
float * weight_aggr_sum = weight_sum_up_tensor.contiguous().data<float>();
230-
231-
dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1);
232-
leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
233-
feature_in, feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
234-
root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
235-
feature_aggr_sum, feature_aggr, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size);
236-
237-
dim3 weight_block_dims(CUDA_NUM_THREADS, 1, 1), weight_grid_dims(batch_size, 1, 1);
238-
leaf_root_aggr_kernel <<< weight_grid_dims, weight_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
239-
NULL, weight_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, 1, vertex_size, max_adj_per_node);
240-
root_leaf_prop_kernel <<< weight_grid_dims, weight_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
241-
weight_aggr_sum, weight_sum, edge_weight, sorted_index, sorted_parent_index, batch_size, 1, vertex_size);
242-
224+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(feature_in_tensor.scalar_type(), "refine_forward", [&] {
225+
scalar_t * feature_in = feature_in_tensor.contiguous().data_ptr<scalar_t>();
226+
scalar_t * edge_weight = edge_weight_tensor.contiguous().data_ptr<scalar_t>();
227+
int * sorted_index = sorted_index_tensor.contiguous().data_ptr<int>();
228+
int * sorted_parent_index = sorted_parent_tensor.contiguous().data_ptr<int>();
229+
int * sorted_child_index = sorted_child_tensor.contiguous().data_ptr<int>();
230+
scalar_t * feature_aggr = feature_aggr_tensor.contiguous().data_ptr<scalar_t>();
231+
scalar_t * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data_ptr<scalar_t>();
232+
scalar_t * weight_sum = weight_sum_tensor.contiguous().data_ptr<scalar_t>();
233+
scalar_t * weight_aggr_sum = weight_sum_up_tensor.contiguous().data_ptr<scalar_t>();
234+
235+
dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1);
236+
leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
237+
feature_in, feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
238+
root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
239+
feature_aggr_sum, feature_aggr, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size);
240+
241+
dim3 weight_block_dims(CUDA_NUM_THREADS, 1, 1), weight_grid_dims(batch_size, 1, 1);
242+
leaf_root_aggr_kernel <<< weight_grid_dims, weight_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
243+
static_cast<scalar_t *>(NULL), weight_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, 1, vertex_size, max_adj_per_node);
244+
root_leaf_prop_kernel <<< weight_grid_dims, weight_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
245+
weight_aggr_sum, weight_sum, edge_weight, sorted_index, sorted_parent_index, batch_size, 1, vertex_size);
246+
});
247+
243248
auto feature_out_tensor = feature_aggr_tensor / weight_sum_tensor.unsqueeze(1);
244249
auto result = std::make_tuple(feature_out_tensor, feature_aggr_tensor, feature_aggr_up_tensor,
245-
weight_sum_tensor, weight_sum_up_tensor);
250+
weight_sum_tensor, weight_sum_up_tensor);
251+
246252
return result;
247253
}
248254

@@ -271,28 +277,30 @@ at::Tensor refine_backward_feature(
271277
const int vertex_size = feature_in_tensor.size(2);
272278
const int max_adj_per_node = sorted_child_tensor.size(2);
273279

274-
float * feature_in = feature_in_tensor.contiguous().data<float>();
275-
float * edge_weight = edge_weight_tensor.contiguous().data<float>();
276-
int * sorted_index = sorted_index_tensor.contiguous().data<int>();
277-
int * sorted_parent_index = sorted_parent_tensor.contiguous().data<int>();
278-
int * sorted_child_index = sorted_child_tensor.contiguous().data<int>();
279-
float * feature_aggr = feature_aggr_tensor.contiguous().data<float>();
280-
float * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data<float>();
281-
float * weight_sum = weight_sum_tensor.contiguous().data<float>();
282-
float * weight_aggr_sum = weight_sum_up_tensor.contiguous().data<float>();
283-
float * grad_out = grad_out_tensor.contiguous().data<float>();
284-
float * grad_feature = grad_feature_tensor.contiguous().data<float>();
285-
286-
float * grad_out_norm = grad_out_norm_tensor.contiguous().data<float>();
287-
float * grad_feature_aggr_sum = grad_feature_aggr_sum_tensor.contiguous().data<float>();
288-
289280
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
290281

291-
dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1);
292-
leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
293-
grad_out_norm, grad_feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
294-
root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
295-
grad_feature_aggr_sum, grad_feature, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size);
282+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(feature_in_tensor.scalar_type(), "refine_backward_feature", [&] {
283+
scalar_t * feature_in = feature_in_tensor.contiguous().data_ptr<scalar_t>();
284+
scalar_t * edge_weight = edge_weight_tensor.contiguous().data_ptr<scalar_t>();
285+
int * sorted_index = sorted_index_tensor.contiguous().data_ptr<int>();
286+
int * sorted_parent_index = sorted_parent_tensor.contiguous().data_ptr<int>();
287+
int * sorted_child_index = sorted_child_tensor.contiguous().data_ptr<int>();
288+
scalar_t * feature_aggr = feature_aggr_tensor.contiguous().data_ptr<scalar_t>();
289+
scalar_t * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data_ptr<scalar_t>();
290+
scalar_t * weight_sum = weight_sum_tensor.contiguous().data_ptr<scalar_t>();
291+
scalar_t * weight_aggr_sum = weight_sum_up_tensor.contiguous().data_ptr<scalar_t>();
292+
scalar_t * grad_out = grad_out_tensor.contiguous().data_ptr<scalar_t>();
293+
scalar_t * grad_feature = grad_feature_tensor.contiguous().data_ptr<scalar_t>();
294+
295+
scalar_t * grad_out_norm = grad_out_norm_tensor.contiguous().data_ptr<scalar_t>();
296+
scalar_t * grad_feature_aggr_sum = grad_feature_aggr_sum_tensor.contiguous().data_ptr<scalar_t>();
297+
298+
dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1);
299+
leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
300+
grad_out_norm, grad_feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
301+
root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
302+
grad_feature_aggr_sum, grad_feature, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size);
303+
});
296304

297305
return grad_feature_tensor;
298306
}
@@ -318,49 +326,52 @@ at::Tensor refine_backward_weight(
318326
const int channel_size = feature_in_tensor.size(1);
319327
const int vertex_size = feature_in_tensor.size(2);
320328
const int max_adj_per_node = sorted_child_tensor.size(2);
321-
322-
float * feature_in = feature_in_tensor.contiguous().data<float>();
323-
float * edge_weight = edge_weight_tensor.contiguous().data<float>();
324-
int * sorted_index = sorted_index_tensor.contiguous().data<int>();
325-
int * sorted_parent_index = sorted_parent_tensor.contiguous().data<int>();
326-
int * sorted_child_index = sorted_child_tensor.contiguous().data<int>();
327-
float * feature_out = feature_out_tensor.contiguous().data<float>();
328-
float * feature_aggr = feature_aggr_tensor.contiguous().data<float>();
329-
float * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data<float>();
330-
float * weight_sum = weight_sum_tensor.contiguous().data<float>();
331-
float * weight_aggr_sum = weight_sum_up_tensor.contiguous().data<float>();
332-
float * grad_out = grad_out_tensor.contiguous().data<float>();
333-
float * grad_weight = grad_weight_tensor.contiguous().data<float>();
334-
329+
335330
auto grad_all_channel_tensor = at::zeros_like(feature_in_tensor, options);
336331
auto grad_norm_all_channel_tensor = at::zeros_like(feature_in_tensor, options);
337332
auto grad_out_norm_aggr_sum_tensor = at::zeros_like(feature_in_tensor, options);
338333
auto feature_grad_aggr_sum_tensor = at::zeros_like(feature_in_tensor, options);
339-
340-
float * grad_all_channel = grad_all_channel_tensor.contiguous().data<float>();
341-
float * grad_norm_all_channel = grad_norm_all_channel_tensor.contiguous().data<float>();
342-
float * grad_out_norm_aggr_sum = grad_out_norm_aggr_sum_tensor.contiguous().data<float>();
343-
float * feature_grad_aggr_sum = feature_grad_aggr_sum_tensor.contiguous().data<float>();
344-
345-
auto grad_out_norm_tensor = grad_out_tensor / weight_sum_tensor.unsqueeze(1);
346-
auto feature_grad_tensor = grad_out_norm_tensor * feature_out_tensor;
347-
float * grad_out_norm = grad_out_norm_tensor.contiguous().data<float>();
348-
float * feature_grad = feature_grad_tensor.contiguous().data<float>();
349334

350335
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
351336

352-
dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1);
353-
leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
354-
grad_out_norm, grad_out_norm_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
355-
leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
356-
feature_grad, feature_grad_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
357-
358-
root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
359-
feature_aggr_sum, grad_out_norm_aggr_sum, feature_aggr, grad_out_norm_aggr_sum, edge_weight, grad_all_channel,
360-
sorted_index, sorted_parent_index, batch_size, channel_size, channel_size, vertex_size);
361-
root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
362-
weight_aggr_sum, feature_grad_aggr_sum, weight_sum, feature_grad_aggr_sum, edge_weight, grad_norm_all_channel,
363-
sorted_index, sorted_parent_index, batch_size, 1, channel_size, vertex_size);
337+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(feature_in_tensor.scalar_type(), "refine_backward_weight", [&] {
338+
scalar_t * feature_in = feature_in_tensor.contiguous().data_ptr<scalar_t>();
339+
scalar_t * edge_weight = edge_weight_tensor.contiguous().data_ptr<scalar_t>();
340+
int * sorted_index = sorted_index_tensor.contiguous().data_ptr<int>();
341+
int * sorted_parent_index = sorted_parent_tensor.contiguous().data_ptr<int>();
342+
int * sorted_child_index = sorted_child_tensor.contiguous().data_ptr<int>();
343+
scalar_t * feature_out = feature_out_tensor.contiguous().data_ptr<scalar_t>();
344+
scalar_t * feature_aggr = feature_aggr_tensor.contiguous().data_ptr<scalar_t>();
345+
scalar_t * feature_aggr_sum = feature_aggr_up_tensor.contiguous().data_ptr<scalar_t>();
346+
scalar_t * weight_sum = weight_sum_tensor.contiguous().data_ptr<scalar_t>();
347+
scalar_t * weight_aggr_sum = weight_sum_up_tensor.contiguous().data_ptr<scalar_t>();
348+
scalar_t * grad_out = grad_out_tensor.contiguous().data_ptr<scalar_t>();
349+
scalar_t * grad_weight = grad_weight_tensor.contiguous().data_ptr<scalar_t>();
350+
351+
scalar_t * grad_all_channel = grad_all_channel_tensor.contiguous().data_ptr<scalar_t>();
352+
scalar_t * grad_norm_all_channel = grad_norm_all_channel_tensor.contiguous().data_ptr<scalar_t>();
353+
scalar_t * grad_out_norm_aggr_sum = grad_out_norm_aggr_sum_tensor.contiguous().data_ptr<scalar_t>();
354+
scalar_t * feature_grad_aggr_sum = feature_grad_aggr_sum_tensor.contiguous().data_ptr<scalar_t>();
355+
356+
auto grad_out_norm_tensor = grad_out_tensor / weight_sum_tensor.unsqueeze(1);
357+
auto feature_grad_tensor = grad_out_norm_tensor * feature_out_tensor;
358+
scalar_t * grad_out_norm = grad_out_norm_tensor.contiguous().data_ptr<scalar_t>();
359+
scalar_t * feature_grad = feature_grad_tensor.contiguous().data_ptr<scalar_t>();
360+
361+
dim3 feature_block_dims(CUDA_NUM_THREADS, 1, 1), feature_grid_dims(batch_size, channel_size, 1);
362+
leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
363+
grad_out_norm, grad_out_norm_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
364+
leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
365+
feature_grad, feature_grad_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
366+
367+
root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
368+
feature_aggr_sum, grad_out_norm_aggr_sum, feature_aggr, grad_out_norm_aggr_sum, edge_weight, grad_all_channel,
369+
sorted_index, sorted_parent_index, batch_size, channel_size, channel_size, vertex_size);
370+
root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof(int) * CUDA_NUM_THREADS, stream >>>(
371+
weight_aggr_sum, feature_grad_aggr_sum, weight_sum, feature_grad_aggr_sum, edge_weight, grad_norm_all_channel,
372+
sorted_index, sorted_parent_index, batch_size, 1, channel_size, vertex_size);
373+
374+
});
364375

365376
grad_weight_tensor = (grad_all_channel_tensor - grad_norm_all_channel_tensor).sum(1);
366377

0 commit comments

Comments (0)
Please sign in to comment.