@@ -81,13 +81,15 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
 
   auto stream = at::cuda::getCurrentCUDAStream();
   auto scalar_type = x.scalar_type();
-  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
-    radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
-        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
-        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
-        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r * r, x.size(0),
-        y.size(0), x.size(1), ptr_x.value().numel() - 1, max_num_neighbors);
-  });
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "_", [&] {
+        radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
+            x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
+            ptr_x.value().data_ptr<int64_t>(),
+            ptr_y.value().data_ptr<int64_t>(), row.data_ptr<int64_t>(),
+            col.data_ptr<int64_t>(), r * r, x.size(0), y.size(0), x.size(1),
+            ptr_x.value().numel() - 1, max_num_neighbors);
+      });
 
   auto mask = row != -1;
   return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
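
Note: the change swaps AT_DISPATCH_FLOATING_TYPES_AND (float, double, plus one extra type) for AT_DISPATCH_FLOATING_TYPES_AND2, which accepts two extra scalar types, so the radius kernel is now instantiated for at::BFloat16 in addition to float, double, and at::Half. The macro compiles the lambda body once per listed dtype and binds scalar_t to the matching C++ type inside the lambda. A minimal standalone sketch of that dispatch pattern, under the assumption of a program built against libtorch (print_first is a hypothetical helper used only for illustration; it is not part of this commit):

#include <ATen/Dispatch.h>
#include <torch/torch.h>

#include <iostream>

// Hypothetical helper: prints the first element of a typed buffer.
template <typename scalar_t>
void print_first(const scalar_t *data) {
  std::cout << static_cast<double>(data[0]) << std::endl;
}

int main() {
  auto x = torch::ones({4}, torch::kBFloat16);
  // Instantiates the lambda for float, double, Half, and BFloat16;
  // at runtime, the branch matching x.scalar_type() runs with
  // scalar_t bound to the concrete C++ type (here at::BFloat16).
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(),
      "print_first", [&] { print_first<scalar_t>(x.data_ptr<scalar_t>()); });
  return 0;
}

Without the BFloat16 entry in the dispatch list, passing a kBFloat16 tensor would raise a "not implemented for 'BFloat16'" dispatch error rather than reaching the kernel.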