Bootstrap Aggregation #229
Changes from all commits
`Cargo.toml` (new file, +36 lines):

```toml
[package]
name = "linfa-ensemble"
version = "0.6.0"
edition = "2018"
authors = ["James Knight <[email protected]>", "James Kay <[email protected]>"]
description = "A general method for creating ensemble classifiers"
license = "MIT/Apache-2.0"

repository = "https://github.com/rust-ml/linfa"
readme = "README.md"

keywords = ["machine-learning", "linfa", "ensemble"]
categories = ["algorithms", "mathematics", "science"]

[features]
default = []
serde = ["serde_crate", "ndarray/serde"]

[dependencies.serde_crate]
package = "serde"
optional = true
version = "1.0"
default-features = false
features = ["std", "derive"]

[dependencies]
ndarray = { version = "0.15", features = ["rayon", "approx"] }
ndarray-rand = "0.14"
rand = "0.8.5"

linfa = { version = "0.6.0", path = "../.." }
linfa-trees = { version = "0.6.0", path = "../linfa-trees" }

[dev-dependencies]
linfa-datasets = { version = "0.6.0", path = "../../datasets/", features = ["iris"] }
```
`README.md` (new file, +21 lines):

# Ensemble Learning

`linfa-ensemble` provides pure Rust implementations of Ensemble Learning algorithms for the Linfa toolkit.

## The Big Picture

`linfa-ensemble` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`.

## Current state

`linfa-ensemble` currently provides an implementation of bootstrap aggregation (bagging) for the other classifiers provided in Linfa.

## Examples

You can find examples in the `examples/` directory. To run bootstrap aggregation over an ensemble of decision trees (a random forest), use:

```bash
$ cargo run --example randomforest_iris --release
```
`examples/randomforest_iris.rs` (new file, +35 lines):

```rust
use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
use linfa_ensemble::EnsembleLearnerParams;
use linfa_trees::DecisionTree;
use ndarray_rand::rand::SeedableRng;
use rand::rngs::SmallRng;

fn main() {
    // Number of models in the ensemble
    let ensemble_size = 100;
    // Proportion of training data given to each model
    let bootstrap_proportion = 0.7;

    // Load dataset
    let mut rng = SmallRng::seed_from_u64(42);
    let (train, test) = linfa_datasets::iris()
        .shuffle(&mut rng)
        .split_with_ratio(0.8);

    // Train ensemble learner model
    let model = EnsembleLearnerParams::new(DecisionTree::params())
        .ensemble_size(ensemble_size)
        .bootstrap_proportion(bootstrap_proportion)
        .fit(&train)
        .unwrap();

    // Return highest ranking predictions
    let final_predictions_ensemble = model.predict(&test);
    println!("Final Predictions: \n{:?}", final_predictions_ensemble);

    let cm = final_predictions_ensemble.confusion_matrix(&test).unwrap();

    println!("{:?}", cm);
    println!(
        "Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {},\n Bootstrap Proportion: {}",
        100.0 * cm.accuracy(),
        ensemble_size,
        bootstrap_proportion
    );
}
```
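The example seeds the shuffle, but `EnsembleLearnerParams::new` draws its bootstrap samples from `rand::thread_rng()` (see `src/ensemble.rs` below), so the trained ensemble can still differ between runs. A minimal sketch of a fully seeded variant using the PR's `new_fixed_rng` constructor; the extra seed value is an arbitrary choice and this snippet is not part of the PR's example:

```rust
use linfa::prelude::Fit;
use linfa_ensemble::EnsembleLearnerParams;
use linfa_trees::DecisionTree;
use ndarray_rand::rand::SeedableRng;
use rand::rngs::SmallRng;

fn main() {
    // Seeded shuffle and split, as in the example above.
    let (train, _test) = linfa_datasets::iris()
        .shuffle(&mut SmallRng::seed_from_u64(42))
        .split_with_ratio(0.8);

    // `new_fixed_rng` pins the RNG that drives the bootstrap sampling inside
    // `fit`, so the per-model training subsets are reproducible as well.
    let _model = EnsembleLearnerParams::new_fixed_rng(
        DecisionTree::params(),
        SmallRng::seed_from_u64(7), // arbitrary seed for the bootstrap sampling
    )
    .ensemble_size(100)
    .bootstrap_proportion(0.7)
    .fit(&train)
    .unwrap();
}
```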
`src/ensemble.rs` (new file, +198 lines):

```rust
use linfa::{
    dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records},
    error::{Error, Result},
    traits::*,
    DatasetBase, ParamGuard,
};
use ndarray::{Array, Array2, Axis, Dimension};
use rand::rngs::ThreadRng;
use rand::Rng;
use std::{cmp::Eq, collections::HashMap, hash::Hash};

pub struct EnsembleLearner<M> {
    pub models: Vec<M>,
}

impl<M> EnsembleLearner<M> {
    // Generates prediction iterator returning predictions from each model
    pub fn generate_predictions<'b, R: Records, T>(
        &'b self,
        x: &'b R,
    ) -> impl Iterator<Item = T> + 'b
    where
        M: Predict<&'b R, T>,
    {
        self.models.iter().map(move |m| m.predict(x))
    }

    // Consumes prediction iterator to return all predictions
    pub fn aggregate_predictions<Ys: Iterator>(
        &self,
        ys: Ys,
    ) -> impl Iterator<
        Item = Vec<(
            Array<
                <Ys::Item as AsTargets>::Elem,
                <<Ys::Item as AsTargets>::Ix as Dimension>::Smaller,
            >,
            usize,
        )>,
    >
    where
        Ys::Item: AsTargets,
        <Ys::Item as AsTargets>::Elem: Copy + Eq + Hash,
    {
        let mut prediction_maps = Vec::new();

        for y in ys {
            let targets = y.as_targets();
            let no_targets = targets.shape()[0];

            for i in 0..no_targets {
                if prediction_maps.len() == i {
                    prediction_maps.push(HashMap::new());
                }
                *prediction_maps[i]
                    .entry(y.as_targets().index_axis(Axis(0), i).to_owned())
                    .or_insert(0) += 1;
            }
        }

        prediction_maps.into_iter().map(|xs| {
            let mut xs: Vec<_> = xs.into_iter().collect();
            xs.sort_by(|(_, x), (_, y)| y.cmp(x));
            xs
        })
    }
}

impl<F: Clone, T, M> PredictInplace<Array2<F>, T> for EnsembleLearner<M>
where
    M: PredictInplace<Array2<F>, T>,
    <T as AsTargets>::Elem: Copy + Eq + Hash,
    T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
{
    fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
        let mut y_array = y.as_targets_mut();
        assert_eq!(
            x.nrows(),
            y_array.len_of(Axis(0)),
            "The number of data points must match the number of outputs."
        );

        let mut predictions = self.generate_predictions(x);
        let aggregated_predictions = self.aggregate_predictions(&mut predictions);

        for (target, output) in y_array
            .axis_iter_mut(Axis(0))
            .zip(aggregated_predictions.into_iter())
        {
            for (t, o) in target.into_iter().zip(output[0].0.iter()) {
                *t = *o;
            }
        }
    }

    fn default_target(&self, x: &Array2<F>) -> T {
        self.models[0].default_target(x)
    }
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerValidParams<P, R> {
    pub ensemble_size: usize,
    pub bootstrap_proportion: f64,
    pub model_params: P,
    pub rng: R,
}
```

An inline review thread is attached to the fields of this struct (some comments are cut off in the capture):

> Why do we need a separate field for […]
>
> Shouldn't […]
>
> Not necessarily, each model in the ensemble just needs its own random set of samples of training data from the complete training data set. There are no constraints on the size of this set other than it being non-empty, so we let the user tune this size as a hyperparameter.
>
> OK so […]
>
> Can you add this behaviour to the docs, along with a general description of […]
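As the thread above describes, each ensemble member is fitted on its own bootstrap sample drawn from the full training set, and `bootstrap_proportion` only controls that sample's size: `fit` (further down in this file) computes it as `ceil(n_rows * bootstrap_proportion)`. A small illustrative sketch of that relationship; the helper name and the concrete numbers are mine, not part of the PR:

```rust
/// Illustrative helper (not part of the PR): how many rows each ensemble
/// member is trained on, given the training-set size and `bootstrap_proportion`.
fn bootstrap_sample_size(n_rows: usize, bootstrap_proportion: f64) -> usize {
    // Mirrors the computation in `fit`:
    // ((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize
    ((n_rows as f64) * bootstrap_proportion).ceil() as usize
}

fn main() {
    // With the iris example's 120-row training split and a proportion of 0.7,
    // every decision tree is fitted on 84 rows sampled with replacement.
    assert_eq!(bootstrap_sample_size(120, 0.7), 84);
}
```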
The file continues:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);

impl<P> EnsembleLearnerParams<P, ThreadRng> {
    pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
        Self::new_fixed_rng(model_params, rand::thread_rng())
    }
}

impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
    pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams<P, R> {
        Self(EnsembleLearnerValidParams {
            ensemble_size: 1,
            bootstrap_proportion: 1.0,
            model_params,
            rng,
        })
    }

    pub fn ensemble_size(mut self, size: usize) -> Self {
        self.0.ensemble_size = size;
        self
    }

    pub fn bootstrap_proportion(mut self, proportion: f64) -> Self {
        self.0.bootstrap_proportion = proportion;
        self
    }
}

impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
    type Checked = EnsembleLearnerValidParams<P, R>;
    type Error = Error;

    fn check_ref(&self) -> Result<&Self::Checked> {
        if self.0.bootstrap_proportion > 1.0 || self.0.bootstrap_proportion <= 0.0 {
            Err(Error::Parameters(format!(
                "Bootstrap proportion should be greater than zero and less than or equal to one, but was {}",
                self.0.bootstrap_proportion
            )))
        } else if self.0.ensemble_size < 1 {
            Err(Error::Parameters(format!(
                "Ensemble size should be at least one, but was {}",
                self.0.ensemble_size
            )))
        } else {
            Ok(&self.0)
        }
    }

    fn check(self) -> Result<Self::Checked> {
        self.check_ref()?;
        Ok(self.0)
    }
}

impl<D, T, P: Fit<Array2<D>, T::Owned, Error>, R: Rng + Clone> Fit<Array2<D>, T, Error>
    for EnsembleLearnerValidParams<P, R>
where
    D: Clone,
    T: FromTargetArrayOwned,
    T::Elem: Copy + Eq + Hash,
    T::Owned: AsTargets,
{
    type Object = EnsembleLearner<P::Object>;

    fn fit(
        &self,
        dataset: &DatasetBase<Array2<D>, T>,
    ) -> core::result::Result<Self::Object, Error> {
        let mut models = Vec::new();
        let mut rng = self.rng.clone();

        let dataset_size =
            ((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize;

        let iter = dataset.bootstrap_samples(dataset_size, &mut rng);

        for train in iter {
            let model = self.model_params.fit(&train).unwrap();
            models.push(model);

            if models.len() == self.ensemble_size {
                break;
            }
        }

        Ok(EnsembleLearner { models })
    }
}
```
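For intuition, the prediction path above is a per-row plurality vote: `aggregate_predictions` counts how often each target value is predicted across the ensemble and sorts the counts in descending order, and `predict_inplace` writes back the most frequent value. A standalone sketch of that counting step using plain integer labels instead of `ndarray` targets; the function name is illustrative and not part of the crate:

```rust
use std::collections::HashMap;

// Illustrative stand-in for the per-row vote in `aggregate_predictions`:
// given one prediction per ensemble member, return the most frequent one.
fn plurality_vote(votes: &[i32]) -> i32 {
    let mut counts: HashMap<i32, usize> = HashMap::new();
    for &v in votes {
        *counts.entry(v).or_insert(0) += 1;
    }
    // Sort by descending count, mirroring `xs.sort_by(|(_, x), (_, y)| y.cmp(x))`.
    let mut counts: Vec<_> = counts.into_iter().collect();
    counts.sort_by(|(_, a), (_, b)| b.cmp(a));
    counts[0].0
}

fn main() {
    // Three models predict class 1, one predicts class 0: the ensemble returns 1.
    assert_eq!(plurality_vote(&[1, 0, 1, 1]), 1);
}
```

Ties are broken arbitrarily by the sort, which matches the behaviour of the HashMap-based implementation above.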
`src/lib.rs` (new file, +3 lines):

```rust
mod ensemble;

pub use ensemble::*;
```