@@ -137,3 +137,168 @@ random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
  return std::make_tuple(n_out, e_out);
}
+
+
+void compute_cdf(const int64_t *rowptr, const float_t *edge_weight,
+                 float_t *edge_weight_cdf, int64_t numel) {
+  /* Convert edge weights to CDF as given in [1]
+
+  [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L148
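+
+  For example, assuming the outgoing edge weights of a node are already
+  normalized to {0.2, 0.3, 0.5}, the corresponding CDF entries become
+  {0.2, 0.5, 1.0}. `get_offset()` below then draws a uniform value in [0, 1)
+  and binary-searches for the first entry that is >= that value.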
+  */
+  at::parallel_for(0, numel - 1, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
+    for (int64_t i = begin; i < end; i++) {
+      int64_t row_start = rowptr[i], row_end = rowptr[i + 1];
+      float_t acc = 0.0;
+
+      for (int64_t j = row_start; j < row_end; j++) {
+        acc += edge_weight[j];
+        edge_weight_cdf[j] = acc;
+      }
+    }
+  });
+}
+
+
+int64_t get_offset(const float_t *edge_weight, int64_t start, int64_t end) {
+  /*
+  The implementation given in [1] utilizes the `searchsorted` function from NumPy.
+  It is also available in PyTorch and its C++ API (via `at::searchsorted()`).
+  However, that implementation is adapted to the general case, where the searched
+  values can be a multidimensional tensor. In our case, we have a 1D tensor of
+  edge weights (in the form of a Cumulative Distribution Function) and a single
+  value whose position we want to compute. To eliminate the overhead introduced
+  by the PyTorch implementation, one can examine the source code of
+  `searchsorted` [2] and find that for our case the whole function call can be
+  reduced to calling the `cus_lower_bound()` function. Unfortunately, we cannot
+  access it directly (the namespace is not exposed in the public API), but the
+  implementation is just a simple binary search. The code was copied here and
+  reduced to the bare minimum.
+
+  [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
+  [2] https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Bucketization.cpp
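+
+  As a purely illustrative example (not taken from [1] or [2]): searching the
+  CDF slice {0.2, 0.5, 1.0} of a node for the drawn value 0.4 returns offset 1,
+  i.e. the second outgoing edge of that node is picked.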
+  */
+  float_t value = ((float_t)rand() / RAND_MAX); // [0, 1)
+  int64_t original_start = start;
+
+  while (start < end) {
+    const int64_t mid = start + ((end - start) >> 1);
+    const float_t mid_val = edge_weight[mid];
+    if (!(mid_val >= value)) {
+      start = mid + 1;
+    }
+    else {
+      end = mid;
+    }
+  }
+
+  return start - original_start;
+}
+
+// See: https://louisabraham.github.io/articles/node2vec-sampling.html
+// See also: https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
+void rejection_sampling_weighted(const int64_t *rowptr, const int64_t *col,
+                                 const float_t *edge_weight_cdf, int64_t *start,
+                                 int64_t *n_out, int64_t *e_out,
+                                 const int64_t numel, const int64_t walk_length,
+                                 const double p, const double q) {
+
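+  // Unnormalized node2vec acceptance weights: 1/p for returning to the previous
+  // node, 1 for moving to a neighbor of the previous node, and 1/q for moving
+  // further away; dividing by `max_prob` scales the largest of them to 1.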
+  double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
+  double prob_0 = 1. / p / max_prob;
+  double prob_1 = 1. / max_prob;
+  double prob_2 = 1. / q / max_prob;
+
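+  // Each walk takes `walk_length` steps, so shrink the parallel grain size
+  // accordingly to keep the amount of work per task roughly at GRAIN_SIZE.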
+  int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
+  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
+    for (auto n = begin; n < end; n++) {
+      int64_t t = start[n], v, x, e_cur, row_start, row_end;
+
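+      // Each walk occupies `walk_length + 1` node slots in `n_out` and
+      // `walk_length` edge slots in `e_out`; an edge id of -1 marks a step that
+      // could not advance because the current node has no outgoing edges.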
+      n_out[n * (walk_length + 1)] = t;
+
+      row_start = rowptr[t], row_end = rowptr[t + 1];
+
+      if (row_end - row_start == 0) {
+        e_cur = -1;
+        v = t;
+      } else {
+        e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
+        v = col[e_cur];
+      }
+      n_out[n * (walk_length + 1) + 1] = v;
+      e_out[n * walk_length] = e_cur;
+
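+      // Remaining steps: `t` holds the previously visited node and `v` the
+      // current one, following the second-order notation used in node2vec.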
+      for (auto l = 1; l < walk_length; l++) {
+        row_start = rowptr[v], row_end = rowptr[v + 1];
+
+        if (row_end - row_start == 0) {
+          e_cur = -1;
+          x = v;
+        } else if (row_end - row_start == 1) {
+          e_cur = row_start;
+          x = col[e_cur];
+        } else {
+          if (p == 1 and q == 1) {
+            e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
+            x = col[e_cur];
+          }
+          else {
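+            // Rejection sampling: draw a candidate edge from the weight CDF and
+            // accept it with prob_0 (it returns to `t`), prob_1 (it leads to a
+            // neighbor of `t`) or prob_2 (otherwise); redraw on rejection.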
+            while (true) {
+              e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
+              x = col[e_cur];
+
+              auto r = ((double)rand() / (RAND_MAX)); // [0, 1)
+
+              if (x == t && r < prob_0)
+                break;
+              else if (is_neighbor(rowptr, col, x, t) && r < prob_1)
+                break;
+              else if (r < prob_2)
+                break;
+            }
+          }
+        }
+
+        n_out[n * (walk_length + 1) + (l + 1)] = x;
+        e_out[n * walk_length + l] = e_cur;
+        t = v;
+        v = x;
+      }
+    }
+  });
+}
+
+
+std::tuple<torch::Tensor, torch::Tensor>
+random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
+                         torch::Tensor edge_weight, torch::Tensor start,
+                         int64_t walk_length, double p, double q) {
+  CHECK_CPU(rowptr);
+  CHECK_CPU(col);
+  CHECK_CPU(edge_weight);
+  CHECK_CPU(start);
+
+  CHECK_INPUT(rowptr.dim() == 1);
+  CHECK_INPUT(col.dim() == 1);
+  CHECK_INPUT(edge_weight.dim() == 1);
+  CHECK_INPUT(start.dim() == 1);
+
+  auto n_out = torch::empty({start.size(0), walk_length + 1}, start.options());
+  auto e_out = torch::empty({start.size(0), walk_length}, start.options());
+
+  auto rowptr_data = rowptr.data_ptr<int64_t>();
+  auto col_data = col.data_ptr<int64_t>();
+  auto edge_weight_data = edge_weight.data_ptr<float_t>();
+  auto start_data = start.data_ptr<int64_t>();
+  auto n_out_data = n_out.data_ptr<int64_t>();
+  auto e_out_data = e_out.data_ptr<int64_t>();
+
+  auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());
+  auto edge_weight_cdf_data = edge_weight_cdf.data_ptr<float_t>();
+
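+  // Build the per-node CDF over the edge weights once up front; `rowptr.numel()`
+  // equals the number of nodes plus one in the CSR representation.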
+  compute_cdf(rowptr_data, edge_weight_data, edge_weight_cdf_data, rowptr.numel());
+
+  rejection_sampling_weighted(rowptr_data, col_data, edge_weight_cdf_data,
+                              start_data, n_out_data, e_out_data, start.numel(),
+                              walk_length, p, q);
+
+  return std::make_tuple(n_out, e_out);
+}