Skip to content

Commit b4b1eb7

Browse files
authored
Re-org with distr::slice, distr::weighted modules (#1548)
- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` - Rename trait `DistString` -> `SampleString` - Rename `DistIter` -> `Iter`, `DistMap` -> `Map` - Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` - Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` - Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex`
1 parent 16eb7de commit b4b1eb7

23 files changed

+354
-334
lines changed

.github/workflows/benches.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ defaults:
2020

2121
jobs:
2222
clippy-fmt:
23-
name: Check Clippy and rustfmt
23+
name: "Benches: Check Clippy and rustfmt"
2424
runs-on: ubuntu-latest
2525
steps:
2626
- uses: actions/checkout@v4
@@ -33,7 +33,7 @@ jobs:
3333
- name: Clippy
3434
run: cargo clippy --all-targets -- -D warnings
3535
benches:
36-
name: Test benchmarks
36+
name: "Benches: Test"
3737
runs-on: ubuntu-latest
3838
steps:
3939
- uses: actions/checkout@v4

.github/workflows/distr_test.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ defaults:
2020

2121
jobs:
2222
clippy-fmt:
23-
name: Check Clippy and rustfmt
23+
name: "distr_test: Check Clippy and rustfmt"
2424
runs-on: ubuntu-latest
2525
steps:
2626
- uses: actions/checkout@v4
@@ -33,7 +33,7 @@ jobs:
3333
- name: Clippy
3434
run: cargo clippy --all-targets -- -D warnings
3535
ks-tests:
36-
name: Run Komogorov Smirnov tests
36+
name: "distr_test: Run Komogorov Smirnov tests"
3737
runs-on: ubuntu-latest
3838
steps:
3939
- uses: actions/checkout@v4

.github/workflows/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
toolchain: stable
2929
components: clippy, rustfmt
3030
- name: Check Clippy
31-
run: cargo clippy --all --all-targets -- -D warnings
31+
run: cargo clippy --workspace -- -D warnings
3232
- name: Check rustfmt
3333
run: cargo fmt --all -- --check
3434

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
1010

1111
## [0.9.0-beta.3] - 2025-01-03
1212
- Add feature `thread_rng` (#1547)
13+
- Move `distr::Slice` -> `distr::slice::Choose`, `distr::EmptySlice` -> `distr::slice::Empty` (#1548)
14+
- Rename trait `distr::DistString` -> `distr::SampleString` (#1548)
15+
- Rename `distr::DistIter` -> `distr::Iter`, `distr::DistMap` -> `distr::Map` (#1548)
16+
- Move `distr::{Weight, WeightError, WeightedIndex}` -> `distr::weighted::{Weight, Error, WeightedIndex}` (#1548)
1317

1418
## [0.9.0-beta.1] - 2024-11-30
1519
- Bump `rand_core` version

benches/benches/distr.rs

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput};
1010
use criterion_cycles_per_byte::CyclesPerByte;
1111

1212
use rand::prelude::*;
13+
use rand_distr::weighted::*;
1314
use rand_distr::*;
1415

1516
// At this time, distributions are optimised for 64-bit platforms.

benches/benches/weighted.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// except according to those terms.
88

99
use criterion::{black_box, criterion_group, criterion_main, Criterion};
10-
use rand::distr::WeightedIndex;
10+
use rand::distr::weighted::WeightedIndex;
1111
use rand::prelude::*;
1212
use rand::seq::index::sample_weighted;
1313

distr_test/tests/weighted.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
mod ks;
1010
use ks::test_discrete;
11-
use rand::distr::{Distribution, WeightedIndex};
11+
use rand::distr::Distribution;
1212
use rand::seq::{IndexedRandom, IteratorRandom};
13-
use rand_distr::{WeightedAliasIndex, WeightedTreeIndex};
13+
use rand_distr::weighted::*;
1414

1515
/// Takes the unnormalized pdf and creates the cdf of a discrete distribution
1616
fn make_cdf(num: usize, f: impl Fn(i64) -> f64) -> impl Fn(i64) -> f64 {

rand_distr/CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [0.5.0-beta.3] - 2025-01-03
88
- Bump `rand` version (#1547)
9+
- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` (#1548)
10+
- Rename trait `DistString` -> `SampleString` (#1548)
11+
- Rename `DistIter` -> `Iter`, `DistMap` -> `Map` (#1548)
12+
- Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` (#1548)
13+
- Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` (#1548)
14+
- Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex` (#1548)
915

1016
## [0.5.0-beta.2] - 2024-11-30
1117
- Bump `rand` version

rand_distr/src/lib.rs

+7-17
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@
3333
//!
3434
//! The following are re-exported:
3535
//!
36-
//! - The [`Distribution`] trait and [`DistIter`] helper type
36+
//! - The [`Distribution`] trait and [`Iter`] helper type
3737
//! - The [`StandardUniform`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`],
38-
//! [`Open01`], [`Bernoulli`], and [`WeightedIndex`] distributions
38+
//! [`Open01`], [`Bernoulli`] distributions
39+
//! - The [`weighted`] module
3940
//!
4041
//! ## Distributions
4142
//!
@@ -76,9 +77,6 @@
7677
//! - [`UnitBall`] distribution
7778
//! - [`UnitCircle`] distribution
7879
//! - [`UnitDisc`] distribution
79-
//! - Alternative implementations for weighted index sampling
80-
//! - [`WeightedAliasIndex`] distribution
81-
//! - [`WeightedTreeIndex`] distribution
8280
//! - Misc. distributions
8381
//! - [`InverseGaussian`] distribution
8482
//! - [`NormalInverseGaussian`] distribution
@@ -94,7 +92,7 @@ extern crate std;
9492
use rand::Rng;
9593

9694
pub use rand::distr::{
97-
uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01,
95+
uniform, Alphanumeric, Bernoulli, BernoulliError, Distribution, Iter, Open01, OpenClosed01,
9896
StandardUniform, Uniform,
9997
};
10098

@@ -128,16 +126,13 @@ pub use self::unit_sphere::UnitSphere;
128126
pub use self::weibull::{Error as WeibullError, Weibull};
129127
pub use self::zeta::{Error as ZetaError, Zeta};
130128
pub use self::zipf::{Error as ZipfError, Zipf};
131-
#[cfg(feature = "alloc")]
132-
pub use rand::distr::{WeightError, WeightedIndex};
133129
pub use student_t::StudentT;
134-
#[cfg(feature = "alloc")]
135-
pub use weighted_alias::WeightedAliasIndex;
136-
#[cfg(feature = "alloc")]
137-
pub use weighted_tree::WeightedTreeIndex;
138130

139131
pub use num_traits;
140132

133+
#[cfg(feature = "alloc")]
134+
pub mod weighted;
135+
141136
#[cfg(test)]
142137
#[macro_use]
143138
mod test {
@@ -189,11 +184,6 @@ mod test {
189184
}
190185
}
191186

192-
#[cfg(feature = "alloc")]
193-
pub mod weighted_alias;
194-
#[cfg(feature = "alloc")]
195-
pub mod weighted_tree;
196-
197187
mod beta;
198188
mod binomial;
199189
mod cauchy;

rand_distr/src/weighted/mod.rs

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright 2018 Developers of the Rand project.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
//! Weighted (index) sampling
10+
//!
11+
//! This module is a superset of [`rand::distr::weighted`].
12+
//!
13+
//! Multiple implementations of weighted index sampling are provided:
14+
//!
15+
//! - [`WeightedIndex`] (a re-export from [`rand`]) supports fast construction
16+
//! and `O(log N)` sampling over `N` weights.
17+
//! It also supports updating weights with `O(N)` time.
18+
//! - [`WeightedAliasIndex`] supports `O(1)` sampling, but due to high
19+
//! construction time many samples are required to outperform [`WeightedIndex`].
20+
//! - [`WeightedTreeIndex`] supports `O(log N)` sampling and
21+
//! update/insertion/removal of weights with `O(log N)` time.
22+
23+
mod weighted_alias;
24+
mod weighted_tree;
25+
26+
pub use rand::distr::weighted::*;
27+
pub use weighted_alias::*;
28+
pub use weighted_tree::*;

rand_distr/src/weighted_alias.rs rand_distr/src/weighted/weighted_alias.rs

+21-21
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
//! This module contains an implementation of alias method for sampling random
1010
//! indices with probabilities proportional to a collection of weights.
1111
12-
use super::WeightError;
12+
use super::Error;
1313
use crate::{uniform::SampleUniform, Distribution, Uniform};
1414
use alloc::{boxed::Box, vec, vec::Vec};
1515
use core::fmt;
@@ -41,7 +41,7 @@ use serde::{Deserialize, Serialize};
4141
/// # Example
4242
///
4343
/// ```
44-
/// use rand_distr::WeightedAliasIndex;
44+
/// use rand_distr::weighted::WeightedAliasIndex;
4545
/// use rand::prelude::*;
4646
///
4747
/// let choices = vec!['a', 'b', 'c'];
@@ -85,14 +85,14 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
8585
/// Creates a new [`WeightedAliasIndex`].
8686
///
8787
/// Error cases:
88-
/// - [`WeightError::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
89-
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number,
88+
/// - [`Error::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
89+
/// - [`Error::InvalidWeight`] when a weight is not-a-number,
9090
/// negative or greater than `max = W::MAX / weights.len()`.
91-
/// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero.
92-
pub fn new(weights: Vec<W>) -> Result<Self, WeightError> {
91+
/// - [`Error::InsufficientNonZero`] when the sum of all weights is zero.
92+
pub fn new(weights: Vec<W>) -> Result<Self, Error> {
9393
let n = weights.len();
9494
if n == 0 || n > u32::MAX as usize {
95-
return Err(WeightError::InvalidInput);
95+
return Err(Error::InvalidInput);
9696
}
9797
let n = n as u32;
9898

@@ -103,7 +103,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
103103
.iter()
104104
.all(|&w| W::ZERO <= w && w <= max_weight_size)
105105
{
106-
return Err(WeightError::InvalidWeight);
106+
return Err(Error::InvalidWeight);
107107
}
108108

109109
// The sum of weights will represent 100% of no alias odds.
@@ -115,7 +115,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
115115
weight_sum
116116
};
117117
if weight_sum == W::ZERO {
118-
return Err(WeightError::InsufficientNonZero);
118+
return Err(Error::InsufficientNonZero);
119119
}
120120

121121
// `weight_sum` would have been zero if `try_from_lossy` causes an error here.
@@ -384,23 +384,23 @@ mod test {
384384
// Floating point special cases
385385
assert_eq!(
386386
WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(),
387-
WeightError::InvalidWeight
387+
Error::InvalidWeight
388388
);
389389
assert_eq!(
390390
WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
391-
WeightError::InsufficientNonZero
391+
Error::InsufficientNonZero
392392
);
393393
assert_eq!(
394394
WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
395-
WeightError::InvalidWeight
395+
Error::InvalidWeight
396396
);
397397
assert_eq!(
398398
WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(),
399-
WeightError::InvalidWeight
399+
Error::InvalidWeight
400400
);
401401
assert_eq!(
402402
WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(),
403-
WeightError::InvalidWeight
403+
Error::InvalidWeight
404404
);
405405
}
406406

@@ -418,11 +418,11 @@ mod test {
418418
// Signed integer special cases
419419
assert_eq!(
420420
WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
421-
WeightError::InvalidWeight
421+
Error::InvalidWeight
422422
);
423423
assert_eq!(
424424
WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(),
425-
WeightError::InvalidWeight
425+
Error::InvalidWeight
426426
);
427427
}
428428

@@ -440,11 +440,11 @@ mod test {
440440
// Signed integer special cases
441441
assert_eq!(
442442
WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
443-
WeightError::InvalidWeight
443+
Error::InvalidWeight
444444
);
445445
assert_eq!(
446446
WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(),
447-
WeightError::InvalidWeight
447+
Error::InvalidWeight
448448
);
449449
}
450450

@@ -491,15 +491,15 @@ mod test {
491491

492492
assert_eq!(
493493
WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
494-
WeightError::InvalidInput
494+
Error::InvalidInput
495495
);
496496
assert_eq!(
497497
WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
498-
WeightError::InsufficientNonZero
498+
Error::InsufficientNonZero
499499
);
500500
assert_eq!(
501501
WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
502-
WeightError::InvalidWeight
502+
Error::InvalidWeight
503503
);
504504
}
505505

0 commit comments

Comments
 (0)