
Commit 11aaf83

Implement weighted random walks
This commit implements weighted biased random walks as in the original node2vec paper. In particular, it adds a new parameter to the `random_walk` function, `edge_weight`, which allows passing edge weights to the underlying random walk generation procedure. If edge weights are set, the function normalizes them per source node (so each node's outgoing weights sum to one) and converts them into CDFs over each node's outgoing edges (needed by the rejection sampling method). The implementation of the new rejection sampling method is based on [1].

[1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69

* Update `random_walk` API
* Implement weighted rejection sampling on CPU
* Implement weighted random walk for GPU (CUDA)
* Compute CDFs using C++/CUDA
* Add tests for weighted random walks
1 parent cc4696b commit 11aaf83
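
Note (illustration, not part of the diff): the per-node normalization the message refers to happens before the C++ op sees the weights. A minimal sketch of what that step has to achieve — the helper name and exact tensor ops are assumptions here, not code from this commit:

#include <torch/torch.h>

// Hypothetical helper: scale each node's outgoing edge weights so they sum
// to 1, which is what compute_cdf() below assumes when building per-row CDFs.
torch::Tensor normalize_per_node(torch::Tensor rowptr, torch::Tensor edge_weight) {
  auto deg = rowptr.diff();                 // out-degree of every node
  auto row = torch::repeat_interleave(deg); // source node of every edge
  auto row_sum = torch::zeros({rowptr.size(0) - 1}, edge_weight.options())
                     .scatter_add_(0, row, edge_weight);
  return edge_weight / row_sum.index_select(0, row);
}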

File tree

7 files changed: +408 −4 lines changed

csrc/cpu/rw_cpu.cpp

Lines changed: 165 additions & 0 deletions
@@ -137,3 +137,168 @@ random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,

  return std::make_tuple(n_out, e_out);
}

void compute_cdf(const int64_t *rowptr, const float_t *edge_weight,
                 float_t *edge_weight_cdf, int64_t numel) {
  /* Convert edge weights to a per-node CDF as given in [1].

     Assumes the weights of each node's outgoing edges have already been
     normalized to sum to 1, so the running sum of every row ends at 1.0.

  [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L148
  */
  at::parallel_for(0, numel - 1, at::internal::GRAIN_SIZE,
                   [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; i++) {
      int64_t row_start = rowptr[i], row_end = rowptr[i + 1];
      float_t acc = 0.0;

      // Running (inclusive) prefix sum over the weights of row i.
      for (int64_t j = row_start; j < row_end; j++) {
        acc += edge_weight[j];
        edge_weight_cdf[j] = acc;
      }
    }
  });
}
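
Note (illustration, not part of the diff): for a node with three outgoing edges whose normalized weights are 0.2, 0.3 and 0.5, `compute_cdf` writes the running sums 0.2, 0.5, 1.0; a uniform draw against that row then selects each edge with exactly those probabilities. A standalone sketch of the per-row prefix sum:

#include <cstdio>

int main() {
  const long rowptr[] = {0, 3};              // one node with three outgoing edges
  const float w[]     = {0.2f, 0.3f, 0.5f};  // normalized edge weights
  float cdf[3];

  float acc = 0.0f;
  for (long j = rowptr[0]; j < rowptr[1]; j++) {
    acc += w[j];
    cdf[j] = acc;
  }
  printf("%.1f %.1f %.1f\n", cdf[0], cdf[1], cdf[2]); // prints: 0.2 0.5 1.0
  return 0;
}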

int64_t get_offset(const float_t *edge_weight, int64_t start, int64_t end) {
  /*
  The implementation given in [1] utilizes the `searchsorted` function in
  NumPy. It is also available in PyTorch and its C++ API (via
  `at::searchsorted()`). However, that implementation is adapted to the
  general case where the searched values can be a multidimensional tensor.
  In our case, we have a 1D tensor of edge weights (in the form of a
  Cumulative Distribution Function) and a single value whose position we
  want to compute. To eliminate the overhead introduced by the PyTorch
  implementation, one can examine the source code of `searchsorted` [2] and
  find that for our case the whole function call can be reduced to calling
  the `cus_lower_bound()` function. Unfortunately, we cannot access it
  directly (the namespace is not exposed to the public API), but the
  implementation is just a simple binary search. The code was copied here
  and reduced to the bare minimum.

  [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
  [2] https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Bucketization.cpp
  */
  float_t value = ((float_t)rand() / RAND_MAX); // uniform in [0, 1]
  int64_t original_start = start;

  // Lower-bound binary search: first index whose CDF value is >= `value`.
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const float_t mid_val = edge_weight[mid];
    if (!(mid_val >= value)) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }

  return start - original_start;
}
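
Note (illustration, not part of the diff): the loop above is the classic lower-bound search, so on the host it agrees with `std::lower_bound`. A minimal check:

#include <algorithm>
#include <cassert>

int main() {
  const float cdf[] = {0.2f, 0.5f, 1.0f};
  const float value = 0.4f;

  // Same search as get_offset(), with the uniform sample fixed to `value`.
  long lo = 0, hi = 3;
  while (lo < hi) {
    const long mid = lo + ((hi - lo) >> 1);
    if (!(cdf[mid] >= value)) lo = mid + 1; else hi = mid;
  }

  assert(lo == std::lower_bound(cdf, cdf + 3, value) - cdf); // both yield 1
  return 0;
}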

// See: https://louisabraham.github.io/articles/node2vec-sampling.html
// See also: https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
void rejection_sampling_weighted(const int64_t *rowptr, const int64_t *col,
                                 const float_t *edge_weight_cdf, int64_t *start,
                                 int64_t *n_out, int64_t *e_out,
                                 const int64_t numel, const int64_t walk_length,
                                 const double p, const double q) {

  // Acceptance thresholds for the three node2vec cases, scaled so the largest
  // equals 1: returning to the previous node (prob_0), moving to a neighbor
  // of the previous node (prob_1), and moving further away (prob_2).
  double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
  double prob_0 = 1. / p / max_prob;
  double prob_1 = 1. / max_prob;
  double prob_2 = 1. / q / max_prob;

  int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
    for (auto n = begin; n < end; n++) {
      int64_t t = start[n], v, x, e_cur, row_start, row_end;

      n_out[n * (walk_length + 1)] = t;

      row_start = rowptr[t], row_end = rowptr[t + 1];

      if (row_end - row_start == 0) {
        e_cur = -1;
        v = t;
      } else {
        e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
        v = col[e_cur];
      }
      n_out[n * (walk_length + 1) + 1] = v;
      e_out[n * walk_length] = e_cur;

      for (auto l = 1; l < walk_length; l++) {
        row_start = rowptr[v], row_end = rowptr[v + 1];

        if (row_end - row_start == 0) {
          e_cur = -1;
          x = v;
        } else if (row_end - row_start == 1) {
          e_cur = row_start;
          x = col[e_cur];
        } else {
          if (p == 1 && q == 1) {
            // Unbiased case: a single weighted draw, no rejection needed.
            e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
            x = col[e_cur];
          } else {
            while (true) {
              e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
              x = col[e_cur];

              auto r = ((double)rand() / (RAND_MAX)); // uniform in [0, 1]

              if (x == t && r < prob_0)
                break;
              else if (is_neighbor(rowptr, col, x, t) && r < prob_1)
                break;
              else if (r < prob_2)
                break;
            }
          }
        }

        n_out[n * (walk_length + 1) + (l + 1)] = x;
        e_out[n * walk_length + l] = e_cur;
        t = v;
        v = x;
      }
    }
  });
}
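
Note (illustration, not part of the diff): a quick numeric check of the thresholds. With p = 1 and q = 0.5, max_prob = max(1, 1, 2) = 2, so prob_0 = 0.5, prob_1 = 0.5 and prob_2 = 1.0 — a step that moves further from the previous node is always accepted, while the other two cases are accepted half the time:

#include <cmath>
#include <cstdio>

int main() {
  // Reproduces the threshold computation from rejection_sampling_weighted().
  const double p = 1.0, q = 0.5;
  const double max_prob = std::fmax(std::fmax(1. / p, 1.), 1. / q); // = 2
  printf("prob_0 = %.2f, prob_1 = %.2f, prob_2 = %.2f\n",
         1. / p / max_prob, 1. / max_prob, 1. / q / max_prob);
  // prints: prob_0 = 0.50, prob_1 = 0.50, prob_2 = 1.00
  return 0;
}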

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
                         torch::Tensor edge_weight, torch::Tensor start,
                         int64_t walk_length, double p, double q) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(edge_weight);
  CHECK_CPU(start);

  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(edge_weight.dim() == 1);
  CHECK_INPUT(start.dim() == 1);

  auto n_out = torch::empty({start.size(0), walk_length + 1}, start.options());
  auto e_out = torch::empty({start.size(0), walk_length}, start.options());

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto edge_weight_data = edge_weight.data_ptr<float_t>();
  auto start_data = start.data_ptr<int64_t>();
  auto n_out_data = n_out.data_ptr<int64_t>();
  auto e_out_data = e_out.data_ptr<int64_t>();

  auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());
  auto edge_weight_cdf_data = edge_weight_cdf.data_ptr<float_t>();

  compute_cdf(rowptr_data, edge_weight_data, edge_weight_cdf_data, rowptr.numel());

  rejection_sampling_weighted(rowptr_data, col_data, edge_weight_cdf_data,
                              start_data, n_out_data, e_out_data, start.numel(),
                              walk_length, p, q);

  return std::make_tuple(n_out, e_out);
}
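
Note (illustration, not part of the commit): a minimal host-side sketch driving the CPU path on a toy weighted triangle graph in CSR form, with weights already normalized per node as the op expects:

void example() {
  auto rowptr = torch::tensor({0, 2, 4, 6}, torch::kLong);
  auto col    = torch::tensor({1, 2, 0, 2, 0, 1}, torch::kLong);
  auto weight = torch::tensor({0.9f, 0.1f, 0.5f, 0.5f, 0.2f, 0.8f});
  auto start  = torch::tensor({0, 1, 2}, torch::kLong);

  torch::Tensor node_seq, edge_seq;
  std::tie(node_seq, edge_seq) =
      random_walk_weighted_cpu(rowptr, col, weight, start,
                               /*walk_length=*/4, /*p=*/1.0, /*q=*/0.5);
  // node_seq has shape [3, 5] (walk_length + 1 nodes per walk); edge_seq has
  // shape [3, 4] with indices into `col` (-1 marks a walk stuck at a node
  // without outgoing edges).
}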

csrc/cpu/rw_cpu.h

Lines changed: 5 additions & 0 deletions
@@ -5,3 +5,8 @@

std::tuple<torch::Tensor, torch::Tensor>
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
                int64_t walk_length, double p, double q);

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
                         torch::Tensor edge_weight, torch::Tensor start,
                         int64_t walk_length, double p, double q);

csrc/cuda/rw_cuda.cu

Lines changed: 160 additions & 0 deletions
@@ -150,3 +150,163 @@ random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,

  return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
}

__global__ void cdf_kernel(const int64_t *rowptr, const float_t *edge_weight,
                           float_t *edge_weight_cdf, int64_t numel) {
  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

  // One thread per node: prefix-sum the weights of that node's edges.
  if (thread_idx < numel - 1) {
    int64_t row_start = rowptr[thread_idx], row_end = rowptr[thread_idx + 1];

    float_t acc = 0.0;

    for (int64_t i = row_start; i < row_end; i++) {
      acc += edge_weight[i];
      edge_weight_cdf[i] = acc;
    }
  }
}

__device__ void get_offset(const float_t *edge_weight, int64_t start, int64_t end,
                           float_t value, int64_t *position_out) {
  int64_t original_start = start;

  // Same lower-bound binary search as the CPU version, except the uniform
  // sample is drawn by the caller (via curand) and passed in as `value`.
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const float_t mid_val = edge_weight[mid];
    if (!(mid_val >= value)) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }

  *position_out = start - original_start;
}

__global__ void
rejection_sampling_weighted_kernel(unsigned int seed, const int64_t *rowptr,
                                   const int64_t *col, const float_t *edge_weight_cdf,
                                   const int64_t *start, int64_t *n_out,
                                   int64_t *e_out, const int64_t walk_length,
                                   const int64_t numel, const double p,
                                   const double q) {

  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Give every thread its own curand subsequence; a shared subsequence of 0
  // would make all walks draw identical random numbers.
  curandState_t state;
  curand_init(seed, thread_idx, 0, &state);

  double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
  double prob_0 = 1. / p / max_prob;
  double prob_1 = 1. / max_prob;
  double prob_2 = 1. / q / max_prob;

  if (thread_idx < numel) {
    int64_t t = start[thread_idx], v, x, e_cur, row_start, row_end, offset;

    n_out[thread_idx] = t;

    row_start = rowptr[t], row_end = rowptr[t + 1];

    if (row_end - row_start == 0) {
      e_cur = -1;
      v = t;
    } else {
      get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
      e_cur = row_start + offset;
      v = col[e_cur];
    }

    n_out[numel + thread_idx] = v;
    e_out[thread_idx] = e_cur;

    for (int64_t l = 1; l < walk_length; l++) {
      row_start = rowptr[v], row_end = rowptr[v + 1];

      if (row_end - row_start == 0) {
        e_cur = -1;
        x = v;
      } else if (row_end - row_start == 1) {
        e_cur = row_start;
        x = col[e_cur];
      } else {
        if (p == 1 && q == 1) {
          get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
          e_cur = row_start + offset;
          x = col[e_cur];
        } else {
          while (true) {
            get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
            e_cur = row_start + offset;
            x = col[e_cur];

            double r = curand_uniform(&state); // (0, 1]

            if (x == t && r < prob_0)
              break;

            // Scan x's neighborhood with its own bounds so that `row_start`
            // and `row_end` keep describing v's edges for the next attempt.
            bool is_neighbor = false;
            int64_t n_start = rowptr[x], n_end = rowptr[x + 1];
            for (int64_t i = n_start; i < n_end; i++) {
              if (col[i] == t) {
                is_neighbor = true;
                break;
              }
            }

            if (is_neighbor && r < prob_1)
              break;
            else if (r < prob_2)
              break;
          }
        }
      }

      // Step-major layout (walks in the fast dimension) keeps these writes
      // coalesced; the host wrapper transposes back afterwards.
      n_out[(l + 1) * numel + thread_idx] = x;
      e_out[l * numel + thread_idx] = e_cur;
      t = v;
      v = x;
    }
  }
}

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col,
                          torch::Tensor edge_weight, torch::Tensor start,
                          int64_t walk_length, double p, double q) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(edge_weight);
  CHECK_CUDA(start);
  cudaSetDevice(rowptr.get_device());

  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(edge_weight.dim() == 1);
  CHECK_INPUT(start.dim() == 1);

  auto n_out = torch::empty({walk_length + 1, start.size(0)}, start.options());
  auto e_out = torch::empty({walk_length, start.size(0)}, start.options());

  auto stream = at::cuda::getCurrentCUDAStream();

  auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());

  cdf_kernel<<<BLOCKS(rowptr.numel()), THREADS, 0, stream>>>(
      rowptr.data_ptr<int64_t>(), edge_weight.data_ptr<float_t>(),
      edge_weight_cdf.data_ptr<float_t>(), rowptr.numel());

  rejection_sampling_weighted_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
      time(NULL), rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
      edge_weight_cdf.data_ptr<float_t>(), start.data_ptr<int64_t>(),
      n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(),
      walk_length, start.numel(), p, q);

  return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
}
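
Note (illustration, not part of the diff): `THREADS` and `BLOCKS` come from the repository's CUDA utilities and are not shown in this commit. The usual pattern — an assumption here, the project's actual values may differ — is a fixed block size plus a ceiling division for the grid:

// Hypothetical definitions matching the common launch-configuration idiom;
// check the csrc/cuda utilities for what the project actually uses.
#define THREADS 256
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)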

csrc/cuda/rw_cuda.h

Lines changed: 5 additions & 0 deletions
@@ -5,3 +5,8 @@

std::tuple<torch::Tensor, torch::Tensor>
random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
                 int64_t walk_length, double p, double q);

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col,
                          torch::Tensor edge_weight, torch::Tensor start,
                          int64_t walk_length, double p, double q);

csrc/rw.cpp

Lines changed: 17 additions & 1 deletion
@@ -33,5 +33,21 @@ random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
  }
}

CLUSTER_API std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted(torch::Tensor rowptr, torch::Tensor col,
                     torch::Tensor edge_weight, torch::Tensor start,
                     int64_t walk_length, double p, double q) {
  if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
    return random_walk_weighted_cuda(rowptr, col, edge_weight, start,
                                     walk_length, p, q);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return random_walk_weighted_cpu(rowptr, col, edge_weight, start,
                                    walk_length, p, q);
  }
}

 static auto registry =
-    torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk);
+    torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk)
+        .op("torch_cluster::random_walk_weighted", &random_walk_weighted);
