14
14
#define CUDA_NUM_THREADS 64
15
15
#define GET_CUDA_CHANNEL (N ) ceil(512 .0f / N)
16
16
17
+ template <typename scalar_t >
17
18
__global__ void root_leaf_prop_kernel (
18
- float * in_data,
19
- float * out_data,
20
- float * weight,
19
+ scalar_t * in_data,
20
+ scalar_t * out_data,
21
+ scalar_t * weight,
21
22
int * sorted_index,
22
23
int * sorted_parent_index,
23
24
int batch_size,
@@ -53,7 +54,7 @@ __global__ void root_leaf_prop_kernel(
53
54
int par_pos = sorted_index[par];
54
55
for (int k = channel_idx * vertex_count; k < channel_size * vertex_count;
55
56
k += channel_step * vertex_count){
56
- float edge_weight = weight[i];
57
+ scalar_t edge_weight = weight[i];
57
58
out_data[cur_pos + k] = in_data[i + k] * (1 - edge_weight * edge_weight) +
58
59
out_data[par_pos + k] * edge_weight;
59
60
__threadfence_block ();
@@ -65,10 +66,11 @@ __global__ void root_leaf_prop_kernel(
65
66
}
66
67
}
67
68
69
+ template <typename scalar_t >
68
70
__global__ void leaf_root_aggr_kernel (
69
- float * in_data,
70
- float * out_data,
71
- float * weight,
71
+ scalar_t * in_data,
72
+ scalar_t * out_data,
73
+ scalar_t * weight,
72
74
int * sorted_index,
73
75
int * sorted_child_index,
74
76
int batch_size,
@@ -113,7 +115,7 @@ __global__ void leaf_root_aggr_kernel(
113
115
int cur_pos = sorted_index[i];
114
116
for (int k = channel_idx * vertex_count; k < channel_size * vertex_count;
115
117
k += channel_step * vertex_count){
116
- float aggr_sum;
118
+ scalar_t aggr_sum;
117
119
if (in_data != NULL )
118
120
aggr_sum = in_data[cur_pos + k];
119
121
else
@@ -131,13 +133,14 @@ __global__ void leaf_root_aggr_kernel(
131
133
}
132
134
}
133
135
136
+ template <typename scalar_t >
134
137
__global__ void root_leaf_grad_kernel (
135
- float * in_data,
136
- float * in_grad,
137
- float * out_data,
138
- float * out_grad,
139
- float * weight,
140
- float * grad,
138
+ scalar_t * in_data,
139
+ scalar_t * in_grad,
140
+ scalar_t * out_data,
141
+ scalar_t * out_grad,
142
+ scalar_t * weight,
143
+ scalar_t * grad,
141
144
int * sorted_index,
142
145
int * sorted_parent_index,
143
146
int batch_size,
@@ -172,14 +175,14 @@ __global__ void root_leaf_grad_kernel(
172
175
int par_thread = par % thread_count;
173
176
if ((cur == 0 ) || (node_per_thread[par_thread] >= par)){
174
177
for (int k = channel_idx; k < channel_size; k += channel_step){
175
- float edge_weight = weight[i];
178
+ scalar_t edge_weight = weight[i];
176
179
int data_offset = (k % data_channel_size) * vertex_count;
177
180
int grad_offset = (k % grad_channel_size) * vertex_count;
178
181
int out_offset = k * vertex_count;
179
182
180
183
if (cur > 0 ){
181
- float left = in_grad[cur + grad_offset] * (out_data[par_pos + data_offset] - edge_weight * in_data[cur + data_offset]);
182
- float right = in_data[cur + data_offset] * (out_grad[par + grad_offset] - edge_weight * in_grad[cur + grad_offset]);
184
+ scalar_t left = in_grad[cur + grad_offset] * (out_data[par_pos + data_offset] - edge_weight * in_data[cur + data_offset]);
185
+ scalar_t right = in_data[cur + data_offset] * (out_grad[par + grad_offset] - edge_weight * in_grad[cur + grad_offset]);
183
186
184
187
grad[cur + out_offset] = left + right;
185
188
out_grad[cur + grad_offset] = in_grad[cur + grad_offset] * (1 - edge_weight * edge_weight) +
@@ -218,31 +221,34 @@ refine_forward(
218
221
219
222
cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
220
223
221
- float * feature_in = feature_in_tensor.contiguous ().data <float >();
222
- float * edge_weight = edge_weight_tensor.contiguous ().data <float >();
223
- int * sorted_index = sorted_index_tensor.contiguous ().data <int >();
224
- int * sorted_parent_index = sorted_parent_tensor.contiguous ().data <int >();
225
- int * sorted_child_index = sorted_child_tensor.contiguous ().data <int >();
226
- float * feature_aggr = feature_aggr_tensor.contiguous ().data <float >();
227
- float * feature_aggr_sum = feature_aggr_up_tensor.contiguous ().data <float >();
228
- float * weight_sum = weight_sum_tensor.contiguous ().data <float >();
229
- float * weight_aggr_sum = weight_sum_up_tensor.contiguous ().data <float >();
230
-
231
- dim3 feature_block_dims (CUDA_NUM_THREADS, 1 , 1 ), feature_grid_dims (batch_size, channel_size, 1 );
232
- leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
233
- feature_in, feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
234
- root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
235
- feature_aggr_sum, feature_aggr, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size);
236
-
237
- dim3 weight_block_dims (CUDA_NUM_THREADS, 1 , 1 ), weight_grid_dims (batch_size, 1 , 1 );
238
- leaf_root_aggr_kernel <<< weight_grid_dims, weight_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
239
- NULL , weight_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, 1 , vertex_size, max_adj_per_node);
240
- root_leaf_prop_kernel <<< weight_grid_dims, weight_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
241
- weight_aggr_sum, weight_sum, edge_weight, sorted_index, sorted_parent_index, batch_size, 1 , vertex_size);
242
-
224
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF (feature_in_tensor.scalar_type (), " refine_forward" , [&] {
225
+ scalar_t * feature_in = feature_in_tensor.contiguous ().data_ptr <scalar_t >();
226
+ scalar_t * edge_weight = edge_weight_tensor.contiguous ().data_ptr <scalar_t >();
227
+ int * sorted_index = sorted_index_tensor.contiguous ().data_ptr <int >();
228
+ int * sorted_parent_index = sorted_parent_tensor.contiguous ().data_ptr <int >();
229
+ int * sorted_child_index = sorted_child_tensor.contiguous ().data_ptr <int >();
230
+ scalar_t * feature_aggr = feature_aggr_tensor.contiguous ().data_ptr <scalar_t >();
231
+ scalar_t * feature_aggr_sum = feature_aggr_up_tensor.contiguous ().data_ptr <scalar_t >();
232
+ scalar_t * weight_sum = weight_sum_tensor.contiguous ().data_ptr <scalar_t >();
233
+ scalar_t * weight_aggr_sum = weight_sum_up_tensor.contiguous ().data_ptr <scalar_t >();
234
+
235
+ dim3 feature_block_dims (CUDA_NUM_THREADS, 1 , 1 ), feature_grid_dims (batch_size, channel_size, 1 );
236
+ leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
237
+ feature_in, feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
238
+ root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
239
+ feature_aggr_sum, feature_aggr, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size);
240
+
241
+ dim3 weight_block_dims (CUDA_NUM_THREADS, 1 , 1 ), weight_grid_dims (batch_size, 1 , 1 );
242
+ leaf_root_aggr_kernel <<< weight_grid_dims, weight_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
243
+ static_cast <scalar_t *>(NULL ), weight_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, 1 , vertex_size, max_adj_per_node);
244
+ root_leaf_prop_kernel <<< weight_grid_dims, weight_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
245
+ weight_aggr_sum, weight_sum, edge_weight, sorted_index, sorted_parent_index, batch_size, 1 , vertex_size);
246
+ });
247
+
243
248
auto feature_out_tensor = feature_aggr_tensor / weight_sum_tensor.unsqueeze (1 );
244
249
auto result = std::make_tuple (feature_out_tensor, feature_aggr_tensor, feature_aggr_up_tensor,
245
- weight_sum_tensor, weight_sum_up_tensor);
250
+ weight_sum_tensor, weight_sum_up_tensor);
251
+
246
252
return result;
247
253
}
248
254
@@ -271,28 +277,30 @@ at::Tensor refine_backward_feature(
271
277
const int vertex_size = feature_in_tensor.size (2 );
272
278
const int max_adj_per_node = sorted_child_tensor.size (2 );
273
279
274
- float * feature_in = feature_in_tensor.contiguous ().data <float >();
275
- float * edge_weight = edge_weight_tensor.contiguous ().data <float >();
276
- int * sorted_index = sorted_index_tensor.contiguous ().data <int >();
277
- int * sorted_parent_index = sorted_parent_tensor.contiguous ().data <int >();
278
- int * sorted_child_index = sorted_child_tensor.contiguous ().data <int >();
279
- float * feature_aggr = feature_aggr_tensor.contiguous ().data <float >();
280
- float * feature_aggr_sum = feature_aggr_up_tensor.contiguous ().data <float >();
281
- float * weight_sum = weight_sum_tensor.contiguous ().data <float >();
282
- float * weight_aggr_sum = weight_sum_up_tensor.contiguous ().data <float >();
283
- float * grad_out = grad_out_tensor.contiguous ().data <float >();
284
- float * grad_feature = grad_feature_tensor.contiguous ().data <float >();
285
-
286
- float * grad_out_norm = grad_out_norm_tensor.contiguous ().data <float >();
287
- float * grad_feature_aggr_sum = grad_feature_aggr_sum_tensor.contiguous ().data <float >();
288
-
289
280
cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
290
281
291
- dim3 feature_block_dims (CUDA_NUM_THREADS, 1 , 1 ), feature_grid_dims (batch_size, channel_size, 1 );
292
- leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
293
- grad_out_norm, grad_feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
294
- root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
295
- grad_feature_aggr_sum, grad_feature, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size);
282
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF (feature_in_tensor.scalar_type (), " refine_backward_feature" , [&] {
283
+ scalar_t * feature_in = feature_in_tensor.contiguous ().data_ptr <scalar_t >();
284
+ scalar_t * edge_weight = edge_weight_tensor.contiguous ().data_ptr <scalar_t >();
285
+ int * sorted_index = sorted_index_tensor.contiguous ().data_ptr <int >();
286
+ int * sorted_parent_index = sorted_parent_tensor.contiguous ().data_ptr <int >();
287
+ int * sorted_child_index = sorted_child_tensor.contiguous ().data_ptr <int >();
288
+ scalar_t * feature_aggr = feature_aggr_tensor.contiguous ().data_ptr <scalar_t >();
289
+ scalar_t * feature_aggr_sum = feature_aggr_up_tensor.contiguous ().data_ptr <scalar_t >();
290
+ scalar_t * weight_sum = weight_sum_tensor.contiguous ().data_ptr <scalar_t >();
291
+ scalar_t * weight_aggr_sum = weight_sum_up_tensor.contiguous ().data_ptr <scalar_t >();
292
+ scalar_t * grad_out = grad_out_tensor.contiguous ().data_ptr <scalar_t >();
293
+ scalar_t * grad_feature = grad_feature_tensor.contiguous ().data_ptr <scalar_t >();
294
+
295
+ scalar_t * grad_out_norm = grad_out_norm_tensor.contiguous ().data_ptr <scalar_t >();
296
+ scalar_t * grad_feature_aggr_sum = grad_feature_aggr_sum_tensor.contiguous ().data_ptr <scalar_t >();
297
+
298
+ dim3 feature_block_dims (CUDA_NUM_THREADS, 1 , 1 ), feature_grid_dims (batch_size, channel_size, 1 );
299
+ leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
300
+ grad_out_norm, grad_feature_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
301
+ root_leaf_prop_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
302
+ grad_feature_aggr_sum, grad_feature, edge_weight, sorted_index, sorted_parent_index, batch_size, channel_size, vertex_size);
303
+ });
296
304
297
305
return grad_feature_tensor;
298
306
}
@@ -318,49 +326,52 @@ at::Tensor refine_backward_weight(
318
326
const int channel_size = feature_in_tensor.size (1 );
319
327
const int vertex_size = feature_in_tensor.size (2 );
320
328
const int max_adj_per_node = sorted_child_tensor.size (2 );
321
-
322
- float * feature_in = feature_in_tensor.contiguous ().data <float >();
323
- float * edge_weight = edge_weight_tensor.contiguous ().data <float >();
324
- int * sorted_index = sorted_index_tensor.contiguous ().data <int >();
325
- int * sorted_parent_index = sorted_parent_tensor.contiguous ().data <int >();
326
- int * sorted_child_index = sorted_child_tensor.contiguous ().data <int >();
327
- float * feature_out = feature_out_tensor.contiguous ().data <float >();
328
- float * feature_aggr = feature_aggr_tensor.contiguous ().data <float >();
329
- float * feature_aggr_sum = feature_aggr_up_tensor.contiguous ().data <float >();
330
- float * weight_sum = weight_sum_tensor.contiguous ().data <float >();
331
- float * weight_aggr_sum = weight_sum_up_tensor.contiguous ().data <float >();
332
- float * grad_out = grad_out_tensor.contiguous ().data <float >();
333
- float * grad_weight = grad_weight_tensor.contiguous ().data <float >();
334
-
329
+
335
330
auto grad_all_channel_tensor = at::zeros_like (feature_in_tensor, options);
336
331
auto grad_norm_all_channel_tensor = at::zeros_like (feature_in_tensor, options);
337
332
auto grad_out_norm_aggr_sum_tensor = at::zeros_like (feature_in_tensor, options);
338
333
auto feature_grad_aggr_sum_tensor = at::zeros_like (feature_in_tensor, options);
339
-
340
- float * grad_all_channel = grad_all_channel_tensor.contiguous ().data <float >();
341
- float * grad_norm_all_channel = grad_norm_all_channel_tensor.contiguous ().data <float >();
342
- float * grad_out_norm_aggr_sum = grad_out_norm_aggr_sum_tensor.contiguous ().data <float >();
343
- float * feature_grad_aggr_sum = feature_grad_aggr_sum_tensor.contiguous ().data <float >();
344
-
345
- auto grad_out_norm_tensor = grad_out_tensor / weight_sum_tensor.unsqueeze (1 );
346
- auto feature_grad_tensor = grad_out_norm_tensor * feature_out_tensor;
347
- float * grad_out_norm = grad_out_norm_tensor.contiguous ().data <float >();
348
- float * feature_grad = feature_grad_tensor.contiguous ().data <float >();
349
334
350
335
cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
351
336
352
- dim3 feature_block_dims (CUDA_NUM_THREADS, 1 , 1 ), feature_grid_dims (batch_size, channel_size, 1 );
353
- leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
354
- grad_out_norm, grad_out_norm_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
355
- leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
356
- feature_grad, feature_grad_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
357
-
358
- root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
359
- feature_aggr_sum, grad_out_norm_aggr_sum, feature_aggr, grad_out_norm_aggr_sum, edge_weight, grad_all_channel,
360
- sorted_index, sorted_parent_index, batch_size, channel_size, channel_size, vertex_size);
361
- root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
362
- weight_aggr_sum, feature_grad_aggr_sum, weight_sum, feature_grad_aggr_sum, edge_weight, grad_norm_all_channel,
363
- sorted_index, sorted_parent_index, batch_size, 1 , channel_size, vertex_size);
337
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF (feature_in_tensor.scalar_type (), " refine_backward_weight" , [&] {
338
+ scalar_t * feature_in = feature_in_tensor.contiguous ().data_ptr <scalar_t >();
339
+ scalar_t * edge_weight = edge_weight_tensor.contiguous ().data_ptr <scalar_t >();
340
+ int * sorted_index = sorted_index_tensor.contiguous ().data_ptr <int >();
341
+ int * sorted_parent_index = sorted_parent_tensor.contiguous ().data_ptr <int >();
342
+ int * sorted_child_index = sorted_child_tensor.contiguous ().data_ptr <int >();
343
+ scalar_t * feature_out = feature_out_tensor.contiguous ().data_ptr <scalar_t >();
344
+ scalar_t * feature_aggr = feature_aggr_tensor.contiguous ().data_ptr <scalar_t >();
345
+ scalar_t * feature_aggr_sum = feature_aggr_up_tensor.contiguous ().data_ptr <scalar_t >();
346
+ scalar_t * weight_sum = weight_sum_tensor.contiguous ().data_ptr <scalar_t >();
347
+ scalar_t * weight_aggr_sum = weight_sum_up_tensor.contiguous ().data_ptr <scalar_t >();
348
+ scalar_t * grad_out = grad_out_tensor.contiguous ().data_ptr <scalar_t >();
349
+ scalar_t * grad_weight = grad_weight_tensor.contiguous ().data_ptr <scalar_t >();
350
+
351
+ scalar_t * grad_all_channel = grad_all_channel_tensor.contiguous ().data_ptr <scalar_t >();
352
+ scalar_t * grad_norm_all_channel = grad_norm_all_channel_tensor.contiguous ().data_ptr <scalar_t >();
353
+ scalar_t * grad_out_norm_aggr_sum = grad_out_norm_aggr_sum_tensor.contiguous ().data_ptr <scalar_t >();
354
+ scalar_t * feature_grad_aggr_sum = feature_grad_aggr_sum_tensor.contiguous ().data_ptr <scalar_t >();
355
+
356
+ auto grad_out_norm_tensor = grad_out_tensor / weight_sum_tensor.unsqueeze (1 );
357
+ auto feature_grad_tensor = grad_out_norm_tensor * feature_out_tensor;
358
+ scalar_t * grad_out_norm = grad_out_norm_tensor.contiguous ().data_ptr <scalar_t >();
359
+ scalar_t * feature_grad = feature_grad_tensor.contiguous ().data_ptr <scalar_t >();
360
+
361
+ dim3 feature_block_dims (CUDA_NUM_THREADS, 1 , 1 ), feature_grid_dims (batch_size, channel_size, 1 );
362
+ leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
363
+ grad_out_norm, grad_out_norm_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
364
+ leaf_root_aggr_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
365
+ feature_grad, feature_grad_aggr_sum, edge_weight, sorted_index, sorted_child_index, batch_size, channel_size, vertex_size, max_adj_per_node);
366
+
367
+ root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
368
+ feature_aggr_sum, grad_out_norm_aggr_sum, feature_aggr, grad_out_norm_aggr_sum, edge_weight, grad_all_channel,
369
+ sorted_index, sorted_parent_index, batch_size, channel_size, channel_size, vertex_size);
370
+ root_leaf_grad_kernel <<< feature_grid_dims, feature_block_dims, sizeof (int ) * CUDA_NUM_THREADS, stream >>> (
371
+ weight_aggr_sum, feature_grad_aggr_sum, weight_sum, feature_grad_aggr_sum, edge_weight, grad_norm_all_channel,
372
+ sorted_index, sorted_parent_index, batch_size, 1 , channel_size, vertex_size);
373
+
374
+ });
364
375
365
376
grad_weight_tensor = (grad_all_channel_tensor - grad_norm_all_channel_tensor).sum (1 );
366
377
0 commit comments