Skip to content

Commit 29cd22b

Browse files
authored
Move torch.jit.script check to test (#194)
* update * update * update
1 parent 89b74f0 commit 29cd22b

16 files changed

+54
-22
lines changed

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
cmake_minimum_required(VERSION 3.0)
22
project(torchcluster)
33
set(CMAKE_CXX_STANDARD 14)
4-
set(TORCHCLUSTER_VERSION 1.6.2)
4+
set(TORCHCLUSTER_VERSION 1.6.3)
55

66
option(WITH_CUDA "Enable CUDA support" OFF)
77
option(WITH_PYTHON "Link to Python when building" ON)

conda/pytorch-cluster/meta.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package:
22
name: pytorch-cluster
3-
version: 1.6.2
3+
version: 1.6.3
44

55
source:
66
path: ../..

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
1212
CUDAExtension)
1313

14-
__version__ = '1.6.2'
14+
__version__ = '1.6.3'
1515
URL = 'https://github.com/rusty1s/pytorch_cluster'
1616

1717
WITH_CUDA = False

test/test_graclus.py

+4
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device):
5050

5151
cluster = graclus_cluster(row, col, weight)
5252
assert_correct(row, col, cluster)
53+
54+
jit = torch.jit.script(graclus_cluster)
55+
cluster = jit(row, col, weight)
56+
assert_correct(row, col, cluster)

test/test_grid.py

+3
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device):
3838

3939
cluster = grid_cluster(pos, size, start, end)
4040
assert cluster.tolist() == test['cluster']
41+
42+
jit = torch.jit.script(grid_cluster)
43+
assert torch.equal(jit(pos, size, start, end), cluster)

test/test_knn.py

+9
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ def test_knn(dtype, device):
3434
edge_index = knn(x, y, 2)
3535
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
3636

37+
jit = torch.jit.script(knn)
38+
edge_index = jit(x, y, 2)
39+
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
40+
3741
edge_index = knn(x, y, 2, batch_x, batch_y)
3842
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
3943

@@ -65,6 +69,11 @@ def test_knn_graph(dtype, device):
6569
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
6670
(3, 2), (0, 3), (2, 3)])
6771

72+
jit = torch.jit.script(knn_graph)
73+
edge_index = jit(x, k=2, flow='source_to_target')
74+
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
75+
(3, 2), (0, 3), (2, 3)])
76+
6877

6978
@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
7079
def test_knn_graph_large(dtype, device):

test/test_radius.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ def test_radius(dtype, device):
3535
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
3636
(1, 2), (1, 5), (1, 6)])
3737

38+
jit = torch.jit.script(radius)
39+
edge_index = jit(x, y, 2, max_num_neighbors=4)
40+
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
41+
(1, 2), (1, 5), (1, 6)])
42+
3843
edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
3944
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
4045
(1, 6)])
@@ -64,12 +69,20 @@ def test_radius_graph(dtype, device):
6469
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
6570
(3, 2), (0, 3), (2, 3)])
6671

72+
jit = torch.jit.script(radius_graph)
73+
edge_index = jit(x, r=2.5, flow='source_to_target')
74+
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
75+
(3, 2), (0, 3), (2, 3)])
76+
6777

6878
@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
6979
def test_radius_graph_large(dtype, device):
7080
x = torch.randn(1000, 3, dtype=dtype, device=device)
7181

72-
edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
82+
edge_index = radius_graph(x,
83+
r=0.5,
84+
flow='target_to_source',
85+
loop=True,
7386
max_num_neighbors=2000)
7487

7588
tree = scipy.spatial.cKDTree(x.cpu().numpy())

test/test_rw.py

+3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ def test_rw_small(device):
3131
out = random_walk(row, col, start, walk_length, num_nodes=3)
3232
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
3333

34+
jit = torch.jit.script(random_walk)
35+
assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out)
36+
3437

3538
@pytest.mark.parametrize('device', devices)
3639
def test_rw_large_with_edge_indices(device):

torch_cluster/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
__version__ = '1.6.2'
6+
__version__ = '1.6.3'
77

88
for library in [
99
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',

torch_cluster/graclus.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import torch
44

55

6-
@torch.jit.script
7-
def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
8-
weight: Optional[torch.Tensor] = None,
9-
num_nodes: Optional[int] = None) -> torch.Tensor:
6+
def graclus_cluster(
7+
row: torch.Tensor,
8+
col: torch.Tensor,
9+
weight: Optional[torch.Tensor] = None,
10+
num_nodes: Optional[int] = None,
11+
) -> torch.Tensor:
1012
"""A greedy clustering algorithm of picking an unmarked vertex and matching
1113
it with one its unmarked neighbors (that maximizes its edge weight).
1214

torch_cluster/grid.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import torch
44

55

6-
@torch.jit.script
7-
def grid_cluster(pos: torch.Tensor, size: torch.Tensor,
8-
start: Optional[torch.Tensor] = None,
9-
end: Optional[torch.Tensor] = None) -> torch.Tensor:
6+
def grid_cluster(
7+
pos: torch.Tensor,
8+
size: torch.Tensor,
9+
start: Optional[torch.Tensor] = None,
10+
end: Optional[torch.Tensor] = None,
11+
) -> torch.Tensor:
1012
"""A clustering algorithm, which overlays a regular grid of user-defined
1113
size over a point cloud and clusters all points within a voxel.
1214

torch_cluster/knn.py

-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import torch
44

55

6-
@torch.jit.script
76
def knn(
87
x: torch.Tensor,
98
y: torch.Tensor,
@@ -83,7 +82,6 @@ def knn(
8382
num_workers)
8483

8584

86-
@torch.jit.script
8785
def knn_graph(
8886
x: torch.Tensor,
8987
k: int,

torch_cluster/radius.py

-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import torch
44

55

6-
@torch.jit.script
76
def radius(
87
x: torch.Tensor,
98
y: torch.Tensor,
@@ -84,7 +83,6 @@ def radius(
8483
max_num_neighbors, num_workers)
8584

8685

87-
@torch.jit.script
8886
def radius_graph(
8987
x: torch.Tensor,
9088
r: float,

torch_cluster/rw.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from torch import Tensor
55

66

7-
@torch.jit.script
87
def random_walk(
98
row: Tensor,
109
col: Tensor,
@@ -55,8 +54,7 @@ def random_walk(
5554
torch.cumsum(deg, 0, out=rowptr[1:])
5655

5756
node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
58-
rowptr, col, start, walk_length, p, q,
59-
)
57+
rowptr, col, start, walk_length, p, q)
6058

6159
if return_edge_indices:
6260
return node_seq, edge_seq

torch_cluster/sampler.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22

33

4-
@torch.jit.script
54
def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
65
assert not start.is_cuda
76

torch_cluster/typing.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
import torch
22

3-
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
3+
try:
4+
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
5+
except Exception:
6+
WITH_PTR_LIST = False

0 commit comments

Comments
 (0)