
Commit e0eb0c1

support for torch.bfloat16 in radius ops (#206)
1 parent ef79a92 · commit e0eb0c1

File tree

  csrc/cuda/radius_cuda.cu
  test/test_radius.py
  torch_cluster/testing.py

3 files changed: +13, -10 lines

csrc/cuda/radius_cuda.cu (+9, -7)

@@ -81,13 +81,15 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
 
   auto stream = at::cuda::getCurrentCUDAStream();
   auto scalar_type = x.scalar_type();
-  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
-    radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
-        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
-        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
-        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r * r, x.size(0),
-        y.size(0), x.size(1), ptr_x.value().numel() - 1, max_num_neighbors);
-  });
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "_", [&] {
+        radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
+            x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
+            ptr_x.value().data_ptr<int64_t>(),
+            ptr_y.value().data_ptr<int64_t>(), row.data_ptr<int64_t>(),
+            col.data_ptr<int64_t>(), r * r, x.size(0), y.size(0), x.size(1),
+            ptr_x.value().numel() - 1, max_num_neighbors);
+      });
 
   auto mask = row != -1;
   return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
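The substantive change is the dispatch macro: ATen's AT_DISPATCH_FLOATING_TYPES_AND generates the type switch for float, double, and one extra scalar type (here Half), while the AND2 variant accepts two extras, so radius_kernel<scalar_t> is now also instantiated for at::BFloat16. Note that the host passes r * r and the kernel compares squared distances, avoiding a per-pair square root. As a rough illustration of those semantics only (not code from this repository; the function name and the neighbor-cap behavior are assumptions), a brute-force PyTorch sketch:

import torch

def radius_bruteforce(x, y, r, max_num_neighbors=32):
    # Hypothetical reference: for every point in y, collect indices of
    # points in x whose *squared* distance is at most r * r, mirroring
    # the kernel above, which never takes a square root.
    # Distances are accumulated in float32, since bfloat16 carries only
    # about 8 bits of mantissa.
    dist2 = torch.cdist(y.float(), x.float()).pow_(2)
    row, col = (dist2 <= r * r).nonzero(as_tuple=True)
    keep = []
    for i in row.unique().tolist():
        # Cap the neighbor count per query point; which neighbors survive
        # the cap is left unspecified in this sketch.
        idx = (row == i).nonzero(as_tuple=True)[0][:max_num_neighbors]
        keep.append(idx)
    keep = torch.cat(keep) if keep else row.new_empty(0)
    return torch.stack([row[keep], col[keep]], dim=0)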

test/test_radius.py (+3, -3)

@@ -4,14 +4,14 @@
 import scipy.spatial
 import torch
 from torch_cluster import radius, radius_graph
-from torch_cluster.testing import devices, grad_dtypes, tensor
+from torch_cluster.testing import devices, floating_dtypes, tensor
 
 
 def to_set(edge_index):
     return set([(i, j) for i, j in edge_index.t().tolist()])
 
 
-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices))
 def test_radius(dtype, device):
     x = tensor([
         [-1, -1],
@@ -52,7 +52,7 @@ def test_radius(dtype, device):
         (1, 6)])
 
 
-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices))
 def test_radius_graph(dtype, device):
     x = tensor([
         [-1, -1],
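With floating_dtypes in place, test_radius and test_radius_graph now also run under torch.bfloat16 on every available device. A minimal smoke test of the user-facing effect, assuming an installed torch-cluster build that includes this commit (the point counts and radius value here are made up):

import torch
from torch_cluster import radius

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# bfloat16 inputs now take the same dispatch path as half/float/double.
x = torch.randn(100, 2, dtype=torch.bfloat16, device=device)
y = torch.randn(10, 2, dtype=torch.bfloat16, device=device)
edge_index = radius(x, y, r=0.5, max_num_neighbors=32)
# Index pairs come back as int64 regardless of the input dtype.
assert edge_index.size(0) == 2 and edge_index.dtype == torch.long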

torch_cluster/testing.py (+1, -0)

@@ -7,6 +7,7 @@
     torch.long
 ]
 grad_dtypes = [torch.half, torch.float, torch.double]
+floating_dtypes = grad_dtypes + [torch.bfloat16]
 
 devices = [torch.device('cpu')]
 if torch.cuda.is_available():
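The new list extends grad_dtypes rather than replacing it, presumably so that bfloat16 is exercised in the forward-only radius tests while gradient-checking suites can keep the original list. Each (dtype, device) pair then becomes its own collected test case; a self-contained sketch of that mechanism (the test name and body are illustrative only):

from itertools import product

import pytest
import torch

grad_dtypes = [torch.half, torch.float, torch.double]
floating_dtypes = grad_dtypes + [torch.bfloat16]

devices = [torch.device('cpu')]
if torch.cuda.is_available():
    devices += [torch.device('cuda')]

# pytest collects one test per (dtype, device) combination, so bfloat16
# is covered on every available device.
@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices))
def test_dtype_roundtrip(dtype, device):
    x = torch.randn(4, 2, device=device).to(dtype)
    assert x.dtype == dtype and x.device.type == device.type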
