Skip to content

Bootstrap Aggregation #229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions algorithms/linfa-ensemble/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[package]
name = "linfa-ensemble"
version = "0.6.0"
edition = "2018"
authors = ["James Knight <[email protected]>", "James Kay <[email protected]>"]
description = "A general method for creating ensemble classifiers"
license = "MIT/Apache-2.0"

repository = "https://github.com/rust-ml/linfa"
readme = "README.md"

keywords = ["machine-learning", "linfa", "ensemble"]
categories = ["algorithms", "mathematics", "science"]

[features]
default = []
serde = ["serde_crate", "ndarray/serde"]

[dependencies.serde_crate]
package = "serde"
optional = true
version = "1.0"
default-features = false
features = ["std", "derive"]

[dependencies]
ndarray = { version = "0.15" , features = ["rayon", "approx"]}
ndarray-rand = "0.14"
rand = "0.8.5"

linfa = { version = "0.6.0", path = "../.." }
linfa-trees = { version = "0.6.0", path = "../linfa-trees"}

[dev-dependencies]
linfa-datasets = { version = "0.6.0", path = "../../datasets/", features = ["iris"] }

21 changes: 21 additions & 0 deletions algorithms/linfa-ensemble/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Ensemble Learning

`linfa-ensemble` provides pure Rust implementations of Ensemble Learning algorithms for the Linfa toolkit.

## The Big Picture

`linfa-ensemble` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`.

## Current state

`linfa-ensemble` currently provides an implementation of bootstrap aggregation (bagging) for other classifiers provided in linfa.

## Examples

You can find examples in the `examples/` directory. To run bootstrap aggregation for an ensemble of decision trees (a Random Forest) use:

```bash
$ cargo run --example randomforest_iris --release
```


35 changes: 35 additions & 0 deletions algorithms/linfa-ensemble/examples/randomforest_iris.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
use linfa_ensemble::EnsembleLearnerParams;
use linfa_trees::DecisionTree;
use ndarray_rand::rand::SeedableRng;
use rand::rngs::SmallRng;

fn main() {
    // Hyperparameters for the bagging ensemble: how many trees to train,
    // and what fraction of the training set each tree sees.
    let ensemble_size = 100;
    let bootstrap_proportion = 0.7;

    // Deterministic shuffle so the train/test split is reproducible.
    let mut rng = SmallRng::seed_from_u64(42);
    let shuffled = linfa_datasets::iris().shuffle(&mut rng);
    let (train, test) = shuffled.split_with_ratio(0.8);

    // Build the ensemble of decision trees and fit it on the training split.
    let params = EnsembleLearnerParams::new(DecisionTree::params())
        .ensemble_size(ensemble_size)
        .bootstrap_proportion(bootstrap_proportion);
    let model = params.fit(&train).unwrap();

    // Majority-vote predictions over the held-out split.
    let final_predictions_ensemble = model.predict(&test);
    println!("Final Predictions: \n{:?}", final_predictions_ensemble);

    let cm = final_predictions_ensemble.confusion_matrix(&test).unwrap();

    println!("{:?}", cm);
    println!("Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {},\n Bootstrap Proportion: {}",
        100.0 * cm.accuracy(), ensemble_size, bootstrap_proportion);
}
198 changes: 198 additions & 0 deletions algorithms/linfa-ensemble/src/ensemble.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use linfa::{
dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records},
error::{Error, Result},
traits::*,
DatasetBase, ParamGuard,
};
use ndarray::{Array, Array2, Axis, Dimension};
use rand::rngs::ThreadRng;
use rand::Rng;
use std::{cmp::Eq, collections::HashMap, hash::Hash};

/// A bagging ensemble: a collection of independently trained models whose
/// predictions are combined by majority vote at inference time.
pub struct EnsembleLearner<M> {
    // The fitted sub-models making up the ensemble.
    pub models: Vec<M>,
}

impl<M> EnsembleLearner<M> {
// Generates prediction iterator returning predictions from each model
pub fn generate_predictions<'b, R: Records, T>(
&'b self,
x: &'b R,
) -> impl Iterator<Item = T> + 'b
where
M: Predict<&'b R, T>,
{
self.models.iter().map(move |m| m.predict(x))

Check warning on line 25 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L25

Added line #L25 was not covered by tests
}

// Consumes prediction iterator to return all predictions
pub fn aggregate_predictions<Ys: Iterator>(
&self,
ys: Ys,
) -> impl Iterator<
Item = Vec<(
Array<
<Ys::Item as AsTargets>::Elem,
<<Ys::Item as AsTargets>::Ix as Dimension>::Smaller,
>,
usize,
)>,
>
where
Ys::Item: AsTargets,
<Ys::Item as AsTargets>::Elem: Copy + Eq + Hash,
{
let mut prediction_maps = Vec::new();

Check warning on line 45 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L45

Added line #L45 was not covered by tests

for y in ys {
let targets = y.as_targets();
let no_targets = targets.shape()[0];

Check warning on line 49 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L47-L49

Added lines #L47 - L49 were not covered by tests

for i in 0..no_targets {
if prediction_maps.len() == i {
prediction_maps.push(HashMap::new());

Check warning on line 53 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L51-L53

Added lines #L51 - L53 were not covered by tests
}
*prediction_maps[i]
.entry(y.as_targets().index_axis(Axis(0), i).to_owned())
.or_insert(0) += 1;

Check warning on line 57 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L55-L57

Added lines #L55 - L57 were not covered by tests
}
}

prediction_maps.into_iter().map(|xs| {
let mut xs: Vec<_> = xs.into_iter().collect();
xs.sort_by(|(_, x), (_, y)| y.cmp(x));
xs

Check warning on line 64 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L61-L64

Added lines #L61 - L64 were not covered by tests
})
}
}

// Majority-vote prediction: each sub-model predicts all rows of `x`, and the
// most-voted target per row is written into `y`.
impl<F: Clone, T, M> PredictInplace<Array2<F>, T> for EnsembleLearner<M>
where
    M: PredictInplace<Array2<F>, T>,
    <T as AsTargets>::Elem: Copy + Eq + Hash,
    T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
{
    fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
        let mut y_array = y.as_targets_mut();
        assert_eq!(
            x.nrows(),
            y_array.len_of(Axis(0)),
            "The number of data points must match the number of outputs."
        );

        // One prediction per model, then per-sample vote counts sorted by
        // descending count (see aggregate_predictions).
        let mut predictions = self.generate_predictions(x);
        let aggregated_predictions = self.aggregate_predictions(&mut predictions);

        for (target, output) in y_array
            .axis_iter_mut(Axis(0))
            .zip(aggregated_predictions.into_iter())
        {
            // output[0] is the (prediction, count) pair with the most votes;
            // copy its elements into this row of the output targets.
            for (t, o) in target.into_iter().zip(output[0].0.iter()) {
                *t = *o;
            }
        }
    }

    // Delegates to the first sub-model to allocate a default target buffer.
    fn default_target(&self, x: &Array2<F>) -> T {
        self.models[0].default_target(x)
    }
}

/// Checked hyperparameters for an `EnsembleLearner`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerValidParams<P, R> {
    // Number of models trained in the ensemble. Distinct from
    // `bootstrap_proportion`: each model receives its own independent
    // bootstrap sample, so the two are unrelated.
    pub ensemble_size: usize,
    // Proportion of the full training set (drawn with replacement) given to
    // each individual model. Must be in (0, 1] — enforced by `check_ref`.
    pub bootstrap_proportion: f64,
    // Hyperparameters used to fit every sub-model.
    pub model_params: P,
    // Source of randomness for bootstrap sampling.
    pub rng: R,
}

/// Unchecked builder for `EnsembleLearnerValidParams`; validated through
/// `ParamGuard` when fitting.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);

impl<P> EnsembleLearnerParams<P, ThreadRng> {
pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
return Self::new_fixed_rng(model_params, rand::thread_rng());

Check warning on line 114 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L113-L114

Added lines #L113 - L114 were not covered by tests
}
}

impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
    /// Builds a parameter set with a caller-supplied RNG so bootstrap
    /// sampling is reproducible.
    ///
    /// Defaults: an ensemble of 1 model trained on 100% of the data.
    pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams<P, R> {
        Self(EnsembleLearnerValidParams {
            ensemble_size: 1,
            bootstrap_proportion: 1.0,
            // Field-init shorthand (was `model_params: model_params` etc.,
            // flagged by clippy::redundant_field_names).
            model_params,
            rng,
        })
    }

    /// Sets the number of models to train in the ensemble.
    pub fn ensemble_size(mut self, size: usize) -> Self {
        self.0.ensemble_size = size;
        self
    }

    /// Sets the proportion of the training set each model is trained on.
    pub fn bootstrap_proportion(mut self, proportion: f64) -> Self {
        self.0.bootstrap_proportion = proportion;
        self
    }
}

impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
    type Checked = EnsembleLearnerValidParams<P, R>;
    type Error = Error;

    /// Validates the hyperparameters without consuming the builder:
    /// `bootstrap_proportion` must be in (0, 1] and `ensemble_size >= 1`.
    fn check_ref(&self) -> Result<&Self::Checked> {
        if self.0.bootstrap_proportion > 1.0 || self.0.bootstrap_proportion <= 0.0 {
            Err(Error::Parameters(format!(
                "Bootstrap proportion should be greater than zero and less than or equal to one, but was {}",
                self.0.bootstrap_proportion
            )))
        } else if self.0.ensemble_size < 1 {
            // BUG FIX: message previously read "should be less than one",
            // contradicting the `< 1` condition it reports on.
            Err(Error::Parameters(format!(
                "Ensemble size should be at least one, but was {}",
                self.0.ensemble_size
            )))
        } else {
            Ok(&self.0)
        }
    }

    fn check(self) -> Result<Self::Checked> {
        self.check_ref()?;
        Ok(self.0)
    }
}

impl<D, T, P: Fit<Array2<D>, T::Owned, Error>, R: Rng + Clone> Fit<Array2<D>, T, Error>
for EnsembleLearnerValidParams<P, R>
where
D: Clone,
T: FromTargetArrayOwned,
T::Elem: Copy + Eq + Hash,
T::Owned: AsTargets,
{
type Object = EnsembleLearner<P::Object>;

fn fit(
&self,
dataset: &DatasetBase<Array2<D>, T>,
) -> core::result::Result<Self::Object, Error> {
let mut models = Vec::new();
let mut rng = self.rng.clone();

Check warning on line 180 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L179-L180

Added lines #L179 - L180 were not covered by tests

let dataset_size =
((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize;

Check warning on line 183 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L182-L183

Added lines #L182 - L183 were not covered by tests

let iter = dataset.bootstrap_samples(dataset_size, &mut rng);

Check warning on line 185 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L185

Added line #L185 was not covered by tests

for train in iter {
let model = self.model_params.fit(&train).unwrap();
models.push(model);

Check warning on line 189 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L187-L189

Added lines #L187 - L189 were not covered by tests

if models.len() == self.ensemble_size {
break;

Check warning on line 192 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L191-L192

Added lines #L191 - L192 were not covered by tests
}
}

Ok(EnsembleLearner { models })

Check warning on line 196 in algorithms/linfa-ensemble/src/ensemble.rs

View check run for this annotation

Codecov / codecov/patch

algorithms/linfa-ensemble/src/ensemble.rs#L196

Added line #L196 was not covered by tests
}
}
3 changes: 3 additions & 0 deletions algorithms/linfa-ensemble/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
//! `linfa-ensemble` provides ensemble learning methods for the linfa toolkit.
//!
//! Currently implements bootstrap aggregation (bagging): an ensemble learner
//! trains several copies of a base model, each on a random bootstrap sample
//! of the training data, and predicts by majority vote across the models.
mod ensemble;

pub use ensemble::*;
10 changes: 5 additions & 5 deletions src/dataset/impl_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{
super::traits::{Predict, PredictInplace},
iter::{ChunksIter, DatasetIter, Iter},
AsSingleTargets, AsTargets, AsTargetsMut, CountedTargets, Dataset, DatasetBase, DatasetView,
Float, FromTargetArray, Label, Labels, Records, Result, TargetDim,
Float, FromTargetArray, FromTargetArrayOwned, Label, Labels, Records, Result, TargetDim,
};
use crate::traits::Fit;
use ndarray::{concatenate, prelude::*, Data, DataMut, Dimension};
Expand Down Expand Up @@ -418,7 +418,7 @@ where
impl<'b, F: Clone, E: Copy + 'b, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
where
D: Data<Elem = F>,
T: FromTargetArray<'b, Elem = E>,
T: FromTargetArrayOwned<Elem = E>,
T::Owned: AsTargets,
{
/// Apply bootstrapping for samples and features
Expand All @@ -441,7 +441,7 @@ where
&'b self,
sample_feature_size: (usize, usize),
rng: &'b mut R,
) -> impl Iterator<Item = DatasetBase<Array2<F>, <T as FromTargetArray<'b>>::Owned>> + 'b {
) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
std::iter::repeat(()).map(move |_| {
// sample with replacement
let indices = (0..sample_feature_size.0)
Expand Down Expand Up @@ -481,7 +481,7 @@ where
&'b self,
num_samples: usize,
rng: &'b mut R,
) -> impl Iterator<Item = DatasetBase<Array2<F>, <T as FromTargetArray<'b>>::Owned>> + 'b {
) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
std::iter::repeat(()).map(move |_| {
// sample with replacement
let indices = (0..num_samples)
Expand Down Expand Up @@ -515,7 +515,7 @@ where
&'b self,
num_features: usize,
rng: &'b mut R,
) -> impl Iterator<Item = DatasetBase<Array2<F>, <T as FromTargetArray<'b>>::Owned>> + 'b {
) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
std::iter::repeat(()).map(move |_| {
let targets = T::new_targets(self.as_targets().to_owned());

Expand Down
Loading