@@ -35,6 +35,11 @@ def test_radius(dtype, device):
35
35
assert to_set (edge_index ) == set ([(0 , 0 ), (0 , 1 ), (0 , 2 ), (0 , 3 ), (1 , 1 ),
36
36
(1 , 2 ), (1 , 5 ), (1 , 6 )])
37
37
38
+ jit = torch .jit .script (radius )
39
+ edge_index = jit (x , y , 2 , max_num_neighbors = 4 )
40
+ assert to_set (edge_index ) == set ([(0 , 0 ), (0 , 1 ), (0 , 2 ), (0 , 3 ), (1 , 1 ),
41
+ (1 , 2 ), (1 , 5 ), (1 , 6 )])
42
+
38
43
edge_index = radius (x , y , 2 , batch_x , batch_y , max_num_neighbors = 4 )
39
44
assert to_set (edge_index ) == set ([(0 , 0 ), (0 , 1 ), (0 , 2 ), (0 , 3 ), (1 , 5 ),
40
45
(1 , 6 )])
@@ -64,12 +69,20 @@ def test_radius_graph(dtype, device):
64
69
assert to_set (edge_index ) == set ([(1 , 0 ), (3 , 0 ), (0 , 1 ), (2 , 1 ), (1 , 2 ),
65
70
(3 , 2 ), (0 , 3 ), (2 , 3 )])
66
71
72
+ jit = torch .jit .script (radius_graph )
73
+ edge_index = jit (x , r = 2.5 , flow = 'source_to_target' )
74
+ assert to_set (edge_index ) == set ([(1 , 0 ), (3 , 0 ), (0 , 1 ), (2 , 1 ), (1 , 2 ),
75
+ (3 , 2 ), (0 , 3 ), (2 , 3 )])
76
+
67
77
68
78
@pytest .mark .parametrize ('dtype,device' , product ([torch .float ], devices ))
69
79
def test_radius_graph_large (dtype , device ):
70
80
x = torch .randn (1000 , 3 , dtype = dtype , device = device )
71
81
72
- edge_index = radius_graph (x , r = 0.5 , flow = 'target_to_source' , loop = True ,
82
+ edge_index = radius_graph (x ,
83
+ r = 0.5 ,
84
+ flow = 'target_to_source' ,
85
+ loop = True ,
73
86
max_num_neighbors = 2000 )
74
87
75
88
tree = scipy .spatial .cKDTree (x .cpu ().numpy ())
0 commit comments