Skip to content

Commit bf9d429

Browse files
authored
Add KS tests for weighted sampling; use A-ExpJ alg with log-keys (#1530)
- Extra testing for weighted sampling - Fix IndexedRandom::choose_multiple_weighted with very small keys - Use A-ExpJ algorithm with BinaryHeap for better performance with large length / amount
1 parent 0ff946c commit bf9d429

File tree

4 files changed

+277
-30
lines changed

4 files changed

+277
-30
lines changed

distr_test/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ edition = "2021"
55
publish = false
66

77
[dev-dependencies]
8-
rand_distr = { path = "../rand_distr", version = "=0.5.0-alpha.1", default-features = false }
8+
rand_distr = { path = "../rand_distr", version = "=0.5.0-alpha.1", default-features = false, features = ["alloc"] }
99
rand = { path = "..", version = "=0.9.0-alpha.1", features = ["small_rng"] }
1010
num-traits = "0.2.19"
1111
# Special functions for testing distributions

distr_test/tests/weighted.rs

+235
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
// Copyright 2024 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
mod ks;
10+
use ks::test_discrete;
11+
use rand::distr::{Distribution, WeightedIndex};
12+
use rand::seq::{IndexedRandom, IteratorRandom};
13+
use rand_distr::{WeightedAliasIndex, WeightedTreeIndex};
14+
15+
/// Takes the unnormalized pdf and creates the cdf of a discrete distribution
16+
fn make_cdf(num: usize, f: impl Fn(i64) -> f64) -> impl Fn(i64) -> f64 {
17+
let mut cdf = Vec::with_capacity(num);
18+
let mut ac = 0.0;
19+
for i in 0..num {
20+
ac += f(i as i64);
21+
cdf.push(ac);
22+
}
23+
24+
let frac = 1.0 / ac;
25+
for x in &mut cdf {
26+
*x *= frac;
27+
}
28+
29+
move |i| {
30+
if i < 0 {
31+
0.0
32+
} else {
33+
cdf[i as usize]
34+
}
35+
}
36+
}
37+
38+
#[test]
39+
fn weighted_index() {
40+
fn test_weights(num: usize, weight: impl Fn(i64) -> f64) {
41+
let distr = WeightedIndex::new((0..num).map(|i| weight(i as i64))).unwrap();
42+
test_discrete(0, distr, make_cdf(num, weight));
43+
}
44+
45+
test_weights(100, |_| 1.0);
46+
test_weights(100, |i| ((i + 1) as f64).ln());
47+
test_weights(100, |i| i as f64);
48+
test_weights(100, |i| (i as f64).powi(3));
49+
test_weights(100, |i| 1.0 / ((i + 1) as f64));
50+
}
51+
52+
#[test]
53+
fn weighted_alias_index() {
54+
fn test_weights(num: usize, weight: impl Fn(i64) -> f64) {
55+
let weights = (0..num).map(|i| weight(i as i64)).collect();
56+
let distr = WeightedAliasIndex::new(weights).unwrap();
57+
test_discrete(0, distr, make_cdf(num, weight));
58+
}
59+
60+
test_weights(100, |_| 1.0);
61+
test_weights(100, |i| ((i + 1) as f64).ln());
62+
test_weights(100, |i| i as f64);
63+
test_weights(100, |i| (i as f64).powi(3));
64+
test_weights(100, |i| 1.0 / ((i + 1) as f64));
65+
}
66+
67+
#[test]
68+
fn weighted_tree_index() {
69+
fn test_weights(num: usize, weight: impl Fn(i64) -> f64) {
70+
let distr = WeightedTreeIndex::new((0..num).map(|i| weight(i as i64))).unwrap();
71+
test_discrete(0, distr, make_cdf(num, weight));
72+
}
73+
74+
test_weights(100, |_| 1.0);
75+
test_weights(100, |i| ((i + 1) as f64).ln());
76+
test_weights(100, |i| i as f64);
77+
test_weights(100, |i| (i as f64).powi(3));
78+
test_weights(100, |i| 1.0 / ((i + 1) as f64));
79+
}
80+
81+
#[test]
82+
fn choose_weighted_indexed() {
83+
struct Adapter<F: Fn(i64) -> f64>(Vec<i64>, F);
84+
impl<F: Fn(i64) -> f64> Distribution<i64> for Adapter<F> {
85+
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
86+
*IndexedRandom::choose_weighted(&self.0[..], rng, |i| (self.1)(*i)).unwrap()
87+
}
88+
}
89+
90+
fn test_weights(num: usize, weight: impl Fn(i64) -> f64) {
91+
let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight);
92+
test_discrete(0, distr, make_cdf(num, &weight));
93+
}
94+
95+
test_weights(100, |_| 1.0);
96+
test_weights(100, |i| ((i + 1) as f64).ln());
97+
test_weights(100, |i| i as f64);
98+
test_weights(100, |i| (i as f64).powi(3));
99+
test_weights(100, |i| 1.0 / ((i + 1) as f64));
100+
}
101+
102+
#[test]
103+
fn choose_one_weighted_indexed() {
104+
struct Adapter<F: Fn(i64) -> f64>(Vec<i64>, F);
105+
impl<F: Fn(i64) -> f64> Distribution<i64> for Adapter<F> {
106+
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
107+
*IndexedRandom::choose_multiple_weighted(&self.0[..], rng, 1, |i| (self.1)(*i))
108+
.unwrap()
109+
.next()
110+
.unwrap()
111+
}
112+
}
113+
114+
fn test_weights(num: usize, weight: impl Fn(i64) -> f64) {
115+
let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight);
116+
test_discrete(0, distr, make_cdf(num, &weight));
117+
}
118+
119+
test_weights(100, |_| 1.0);
120+
test_weights(100, |i| ((i + 1) as f64).ln());
121+
test_weights(100, |i| i as f64);
122+
test_weights(100, |i| (i as f64).powi(3));
123+
test_weights(100, |i| 1.0 / ((i + 1) as f64));
124+
}
125+
126+
#[test]
127+
fn choose_two_weighted_indexed() {
128+
struct Adapter<F: Fn(i64) -> f64>(Vec<i64>, F);
129+
impl<F: Fn(i64) -> f64> Distribution<i64> for Adapter<F> {
130+
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
131+
let mut iter =
132+
IndexedRandom::choose_multiple_weighted(&self.0[..], rng, 2, |i| (self.1)(*i))
133+
.unwrap();
134+
let mut a = *iter.next().unwrap();
135+
let mut b = *iter.next().unwrap();
136+
assert!(iter.next().is_none());
137+
if b < a {
138+
std::mem::swap(&mut a, &mut b);
139+
}
140+
a * self.0.len() as i64 + b
141+
}
142+
}
143+
144+
fn test_weights(num: usize, weight: impl Fn(i64) -> f64) {
145+
let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight);
146+
147+
let pmf1 = (0..num).map(|i| weight(i as i64)).collect::<Vec<f64>>();
148+
let sum: f64 = pmf1.iter().sum();
149+
let frac = 1.0 / sum;
150+
151+
let mut ac = 0.0;
152+
let mut cdf = Vec::with_capacity(num * num);
153+
for a in 0..num {
154+
for b in 0..num {
155+
if a < b {
156+
let pa = pmf1[a] * frac;
157+
let pab = pa * pmf1[b] / (sum - pmf1[a]);
158+
159+
let pb = pmf1[b] * frac;
160+
let pba = pb * pmf1[a] / (sum - pmf1[b]);
161+
162+
ac += pab + pba;
163+
}
164+
cdf.push(ac);
165+
}
166+
}
167+
assert!((cdf.last().unwrap() - 1.0).abs() < 1e-9);
168+
169+
let cdf = |i| {
170+
if i < 0 {
171+
0.0
172+
} else {
173+
cdf[i as usize]
174+
}
175+
};
176+
177+
test_discrete(0, distr, cdf);
178+
}
179+
180+
test_weights(100, |_| 1.0);
181+
test_weights(100, |i| ((i + 1) as f64).ln());
182+
test_weights(100, |i| i as f64);
183+
test_weights(100, |i| (i as f64).powi(3));
184+
test_weights(100, |i| 1.0 / ((i + 1) as f64));
185+
test_weights(10, |i| ((i + 1) as f64).powi(-8));
186+
}
187+
188+
#[test]
189+
fn choose_iterator() {
190+
struct Adapter<I>(I);
191+
impl<I: Clone + Iterator<Item = i64>> Distribution<i64> for Adapter<I> {
192+
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
193+
IteratorRandom::choose(self.0.clone(), rng).unwrap()
194+
}
195+
}
196+
197+
let distr = Adapter((0..100).map(|i| i as i64));
198+
test_discrete(0, distr, make_cdf(100, |_| 1.0));
199+
}
200+
201+
#[test]
202+
fn choose_stable_iterator() {
203+
struct Adapter<I>(I);
204+
impl<I: Clone + Iterator<Item = i64>> Distribution<i64> for Adapter<I> {
205+
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
206+
IteratorRandom::choose_stable(self.0.clone(), rng).unwrap()
207+
}
208+
}
209+
210+
let distr = Adapter((0..100).map(|i| i as i64));
211+
test_discrete(0, distr, make_cdf(100, |_| 1.0));
212+
}
213+
214+
#[test]
215+
fn choose_two_iterator() {
216+
struct Adapter<I>(I);
217+
impl<I: Clone + Iterator<Item = i64>> Distribution<i64> for Adapter<I> {
218+
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
219+
let mut buf = [0; 2];
220+
IteratorRandom::choose_multiple_fill(self.0.clone(), rng, &mut buf);
221+
buf.sort_unstable();
222+
assert!(buf[0] < 99 && buf[1] >= 1);
223+
let a = buf[0];
224+
4950 - (99 - a) * (100 - a) / 2 + buf[1] - a - 1
225+
}
226+
}
227+
228+
let distr = Adapter((0..100).map(|i| i as i64));
229+
230+
test_discrete(
231+
0,
232+
distr,
233+
|i| if i < 0 { 0.0 } else { (i + 1) as f64 / 4950.0 },
234+
);
235+
}

src/seq/index.rs

+33-21
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,8 @@ where
333333
/// ordering). The weights are to be provided by the input function `weights`,
334334
/// which will be called once for each index.
335335
///
336-
/// This implementation uses the algorithm described by Efraimidis and Spirakis
337-
/// in this paper: <https://doi.org/10.1016/j.ipl.2005.11.003>
336+
/// This implementation is based on the algorithm A-ExpJ as found in
337+
/// [Efraimidis and Spirakis, 2005](https://doi.org/10.1016/j.ipl.2005.11.003).
338338
/// It uses `O(length + amount)` space and `O(length)` time.
339339
///
340340
/// Error cases:
@@ -354,7 +354,7 @@ where
354354
N: UInt,
355355
IndexVec: From<Vec<N>>,
356356
{
357-
use std::cmp::Ordering;
357+
use std::{cmp::Ordering, collections::BinaryHeap};
358358

359359
if amount == N::zero() {
360360
return Ok(IndexVec::U32(Vec::new()));
@@ -373,9 +373,9 @@ where
373373

374374
impl<N> Ord for Element<N> {
375375
fn cmp(&self, other: &Self) -> Ordering {
376-
// partial_cmp will always produce a value,
377-
// because we check that the weights are not nan
378-
self.key.partial_cmp(&other.key).unwrap()
376+
// unwrap() should not panic since weights should not be NaN
377+
// We reverse so that BinaryHeap::peek shows the smallest item
378+
self.key.partial_cmp(&other.key).unwrap().reverse()
379379
}
380380
}
381381

@@ -387,12 +387,14 @@ where
387387

388388
impl<N> Eq for Element<N> {}
389389

390-
let mut candidates = Vec::with_capacity(length.as_usize());
390+
let mut candidates = BinaryHeap::with_capacity(amount.as_usize());
391391
let mut index = N::zero();
392-
while index < length {
392+
while index < length && candidates.len() < amount.as_usize() {
393393
let weight = weight(index.as_usize()).into();
394394
if weight > 0.0 {
395-
let key = rng.random::<f64>().powf(1.0 / weight);
395+
// We use the log of the key used in A-ExpJ to improve precision
396+
// for small weights:
397+
let key = rng.random::<f64>().ln() / weight;
396398
candidates.push(Element { index, key });
397399
} else if !(weight >= 0.0) {
398400
return Err(WeightError::InvalidWeight);
@@ -401,23 +403,33 @@ where
401403
index += N::one();
402404
}
403405

404-
let avail = candidates.len();
405-
if avail < amount.as_usize() {
406+
if candidates.len() < amount.as_usize() {
406407
return Err(WeightError::InsufficientNonZero);
407408
}
408409

409-
// Partially sort the array to find the `amount` elements with the greatest
410-
// keys. Do this by using `select_nth_unstable` to put the elements with
411-
// the *smallest* keys at the beginning of the list in `O(n)` time, which
412-
// provides equivalent information about the elements with the *greatest* keys.
413-
let (_, mid, greater) = candidates.select_nth_unstable(avail - amount.as_usize());
410+
let mut x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
411+
while index < length {
412+
let weight = weight(index.as_usize()).into();
413+
if weight > 0.0 {
414+
x -= weight;
415+
if x <= 0.0 {
416+
let min_candidate = candidates.pop().unwrap();
417+
let t = (min_candidate.key * weight).exp();
418+
let key = rng.random_range(t..1.0).ln() / weight;
419+
candidates.push(Element { index, key });
420+
421+
x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
422+
}
423+
} else if !(weight >= 0.0) {
424+
return Err(WeightError::InvalidWeight);
425+
}
414426

415-
let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
416-
result.push(mid.index);
417-
for element in greater {
418-
result.push(element.index);
427+
index += N::one();
419428
}
420-
Ok(IndexVec::from(result))
429+
430+
Ok(IndexVec::from(
431+
candidates.iter().map(|elt| elt.index).collect(),
432+
))
421433
}
422434

423435
/// Randomly sample exactly `amount` indices from `0..length`, using Floyd's

src/seq/slice.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -732,17 +732,17 @@ mod test {
732732
use super::*;
733733

734734
// The theoretical probabilities of the different outcomes are:
735-
// AB: 0.5 * 0.5 = 0.250
736-
// AC: 0.5 * 0.5 = 0.250
737-
// BA: 0.25 * 0.67 = 0.167
738-
// BC: 0.25 * 0.33 = 0.082
739-
// CA: 0.25 * 0.67 = 0.167
740-
// CB: 0.25 * 0.33 = 0.082
741-
let choices = [('a', 2), ('b', 1), ('c', 1)];
735+
// AB: 0.5 * 0.667 = 0.3333
736+
// AC: 0.5 * 0.333 = 0.1667
737+
// BA: 0.333 * 0.75 = 0.25
738+
// BC: 0.333 * 0.25 = 0.0833
739+
// CA: 0.167 * 0.6 = 0.1
740+
// CB: 0.167 * 0.4 = 0.0667
741+
let choices = [('a', 3), ('b', 2), ('c', 1)];
742742
let mut rng = crate::test::rng(414);
743743

744744
let mut results = [0i32; 3];
745-
let expected_results = [4167, 4167, 1666];
745+
let expected_results = [5833, 2667, 1500];
746746
for _ in 0..10000 {
747747
let result = choices
748748
.choose_multiple_weighted(&mut rng, 2, |item| item.1)

0 commit comments

Comments
 (0)