From 1345273126848b60cd2a34ceec3fb36b33515b29 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sat, 11 May 2019 14:54:46 +0100
Subject: [PATCH 01/35] Initialize project

---
 .gitignore | 3 +++
 Cargo.toml | 7 +++++++
 src/lib.rs | 7 +++++++
 3 files changed, 17 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 Cargo.toml
 create mode 100644 src/lib.rs

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2f88dba
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+/target
+**/*.rs.bk
+Cargo.lock
\ No newline at end of file
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..7410715
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,7 @@
+[package]
+name = "linfa"
+version = "0.1.0"
+authors = ["LukeMathWalker "]
+edition = "2018"
+
+[dependencies]
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..31e1bb2
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,7 @@
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn it_works() {
+        assert_eq!(2 + 2, 4);
+    }
+}

From c602eeccf4515460587cfe32cf13df31a0383236 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sat, 11 May 2019 15:01:11 +0100
Subject: [PATCH 02/35] Start from rusty-machine approach

---
 src/lib.rs | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 31e1bb2..067994b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,7 +1,17 @@
-#[cfg(test)]
-mod tests {
-    #[test]
-    fn it_works() {
-        assert_eq!(2 + 2, 4);
-    }
+use std::error::Error;
+
+pub trait Trainer<M>
+    where M: Model
+{
+    type Input;
+    type Target;
+
+    fn train(self, inputs: &Self::Input) -> Result<M, Box<dyn Error>>;
 }
+
+pub trait Model {
+    type Input;
+    type Output;
+
+    fn predict(&self, inputs: &Self::Input) -> Result<Self::Output, Box<dyn Error>>;
+}
\ No newline at end of file

From 7407331f672fb3e3f28f1d15bd76d3256ad4f20f Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sat, 11 May 2019 15:15:24 +0100
Subject: [PATCH 03/35] Explain rationale

---
 src/lib.rs | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 067994b..a76916d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,17 +1,31 @@
-use std::error::Error;
+use std::error;
 
 pub trait Trainer<M>
     where M: Model
 {
     type Input;
     type Target;
+    type Error: error::Error;
 
-    fn train(self, inputs: &Self::Input) -> Result<M, Box<dyn Error>>;
+    fn train(self, inputs: &Self::Input) -> Result<M, Self::Error>;
 }
 
+/// The basic `Model` trait.
+///
+/// It is training-agnostic: a model takes an input and returns an output.
+///
+/// There might be multiple ways to discover the best settings for every
+/// particular algorithm (e.g. training a logistic regressor using
+/// a pseudo-inverse matrix vs using gradient descent).
+/// It doesn't matter: the end result, the model, is a set of parameters.
+/// The way those parameters originated is an orthogonal concept.
+///
+/// In the same way, it has no notion of loss or "correct" predictions.
+/// Those concepts are embedded elsewhere.
 pub trait Model {
     type Input;
     type Output;
+    type Error: error::Error;
 
-    fn predict(&self, inputs: &Self::Input) -> Result<Self::Output, Box<dyn Error>>;
+    fn predict(&self, inputs: &Self::Input) -> Result<Self::Output, Self::Error>;
 }
\ No newline at end of file
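To make the `Model` contract concrete, here is a minimal sketch of an implementor, assuming the trait from [PATCH 03/35] is in scope. `MeanModel` and its `EmptyBatch` error are hypothetical names for illustration, not part of the patch series:

```rust
use std::{error, fmt};

/// Hypothetical error type for these sketches.
#[derive(Debug)]
pub struct EmptyBatch;

impl fmt::Display for EmptyBatch {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "the batch contained no samples")
    }
}

impl error::Error for EmptyBatch {}

/// A model with a single parameter: it predicts the same value for every sample.
pub struct MeanModel {
    pub mean: f64,
}

impl Model for MeanModel {
    type Input = Vec<f64>;
    type Output = Vec<f64>;
    type Error = EmptyBatch;

    fn predict(&self, inputs: &Self::Input) -> Result<Self::Output, Self::Error> {
        if inputs.is_empty() {
            return Err(EmptyBatch);
        }
        // One prediction per sample, regardless of the sample's value.
        Ok(vec![self.mean; inputs.len()])
    }
}
```

Note how the model carries no notion of how `mean` was obtained: exactly the training-agnosticism the doc comment argues for.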
From a808405ffc9ca15035e7a17f468a1a4c35d5d0ac Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sat, 11 May 2019 16:01:19 +0100
Subject: [PATCH 04/35] Restructure the trainer class

---
 src/lib.rs | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index a76916d..2e8d09d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,11 +3,10 @@ use std::error;
 pub trait Trainer<M>
     where M: Model
 {
-    type Input;
-    type Target;
     type Error: error::Error;
 
-    fn train(self, inputs: &Self::Input) -> Result<M, Self::Error>;
+    fn train<L>(&self, inputs: &M::Input, targets: &M::Output, loss: L) -> Result<M, Self::Error>
+        where L: FnMut(&M::Output, &M::Output) -> f64;
 }

From 866b1ed521eb1ef6a3df029879ba28270aeded78 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sat, 11 May 2019 16:02:12 +0100
Subject: [PATCH 05/35] Cargo fmt + comment

---
 src/lib.rs | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 2e8d09d..96f667d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,12 +1,15 @@
 use std::error;
 
 pub trait Trainer<M>
-    where M: Model
+where
+    M: Model,
 {
     type Error: error::Error;
 
-    fn train<L>(&self, inputs: &M::Input, targets: &M::Output, loss: L) -> Result<M, Self::Error>
-        where L: FnMut(&M::Output, &M::Output) -> f64;
+    fn train<L>(&self, inputs: &M::Input, targets: &M::Output, loss: L) -> Result<M, Self::Error>
+    // Returning f64 is arbitrary, but I didn't want to flesh out a Loss trait yet
+    where
+        L: FnMut(&M::Output, &M::Output) -> f64;
 }

From 0fa4e2bd576e9bfbca5fe1cb357fb3ae0d7eed98 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sat, 11 May 2019 16:10:37 +0100
Subject: [PATCH 06/35] Add rationale for Optimizer trait (former Trainer)

---
 src/lib.rs | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/lib.rs b/src/lib.rs
index 96f667d..d24219c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,15 @@
 use std::error;
 
-pub trait Trainer<M>
+/// Where information is distilled from data.
+///
+/// `Optimizer` is generic over a type `M` implementing the `Model` trait: `M` is used to
+/// constrain what type of inputs and targets are acceptable, as well as what signature the
+/// loss function should have.
+///
+/// The output of the loss function is currently unconstrained: should it be an associated
+/// type of the `Optimizer` trait itself? Should we add it as a generic parameter of the
+/// `train` method, with a set of reasonable trait bounds?
+pub trait Optimizer<M>
 where
     M: Model,
From 6bc7a71d9e3e4b237cead71090fa8aae62b71426 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sat, 11 May 2019 17:38:23 +0100
Subject: [PATCH 07/35] Introduce Blueprint trait and add details to Optimizer
 docs.

---
 src/lib.rs | 30 +++++++++++++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/src/lib.rs b/src/lib.rs
index d24219c..25cf30c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,6 +6,14 @@ use std::error;
 /// constrain what type of inputs and targets are acceptable, as well as what signature the
 /// loss function should have.
 ///
+/// `train` takes an instance of `M` as one of its inputs, `model`: it doesn't matter if `model`
+/// has been through several rounds of training before, or if it just came out of a `Blueprint`
+/// using `initialize` - it's consumed by `train` and a new model is returned.
+///
+/// This means that there is no difference between one-shot training and incremental training.
+/// Furthermore, the optimizer doesn't have to "own" the model or know anything about its hyperparameters,
+/// because it never has to initialize it.
+///
 /// The output of the loss function is currently unconstrained: should it be an associated
 /// type of the `Optimizer` trait itself? Should we add it as a generic parameter of the
 /// `train` method, with a set of reasonable trait bounds?
@@ -15,12 +23,32 @@ where
 {
     type Error: error::Error;
 
-    fn train<L>(&self, inputs: &M::Input, targets: &M::Output, loss: L) -> Result<M, Self::Error>
+    fn train<L>(&self, inputs: &M::Input, targets: &M::Output, model: M, loss: L) -> Result<M, Self::Error>
     // Returning f64 is arbitrary, but I didn't want to flesh out a Loss trait yet
     where
         L: FnMut(&M::Output, &M::Output) -> f64;
 }
 
+/// Where `Model`s are forged.
+///
+/// `Blueprint`s are used to specify how to build and initialize an instance of the model type `M`.
+///
+/// For the same model type `M`, nothing prevents a user from providing more than one `Blueprint`:
+/// multiple initialization strategies can sometimes be used to build the same model type.
+///
+/// Each of these strategies can take different (hyper)parameters, even though they return an
+/// instance of the same model type in the end.
+///
+/// The initialization procedure could be data-dependent, hence the signature of `initialize`.
+pub trait Blueprint<M>
+where
+    M: Model,
+{
+    type Error: error::Error;
+
+    fn initialize(&self, inputs: &M::Input, targets: &M::Output) -> Result<M, Self::Error>;
+}
+
 /// The basic `Model` trait.

From 60e649b52bfce64f3e2a01c06057541944a4d436 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sat, 11 May 2019 17:41:05 +0100
Subject: [PATCH 08/35] Cargo fmt

---
 src/lib.rs | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/lib.rs b/src/lib.rs
index 25cf30c..c7b28f0 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -23,7 +23,13 @@ where
 {
     type Error: error::Error;
 
-    fn train<L>(&self, inputs: &M::Input, targets: &M::Output, model: M, loss: L) -> Result<M, Self::Error>
+    fn train<L>(
+        &self,
+        inputs: &M::Input,
+        targets: &M::Output,
+        model: M,
+        loss: L,
+    ) -> Result<M, Self::Error>
     // Returning f64 is arbitrary, but I didn't want to flesh out a Loss trait yet
     where
         L: FnMut(&M::Output, &M::Output) -> f64;
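Continuing the hypothetical `MeanModel` sketch from above: a minimal, data-dependent `Blueprint` implementation against the trait as it stands after [PATCH 08/35]. The `MeanBlueprint` name and `fallback` hyperparameter are assumptions for illustration:

```rust
/// Hypothetical blueprint: seed the parameter from the targets, or fall back
/// to a fixed value when the batch is empty.
#[derive(Clone)]
pub struct MeanBlueprint {
    pub fallback: f64,
}

impl Blueprint<MeanModel> for MeanBlueprint {
    type Error = EmptyBatch;

    fn initialize(&self, _inputs: &Vec<f64>, targets: &Vec<f64>) -> Result<MeanModel, Self::Error> {
        // Data-dependent initialization, as the signature allows.
        let mean = if targets.is_empty() {
            self.fallback
        } else {
            targets.iter().sum::<f64>() / targets.len() as f64
        };
        Ok(MeanModel { mean })
    }
}
```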
---
 src/lib.rs | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index c7b28f0..668bb6c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,10 +13,6 @@ use std::error;
 /// This means that there is no difference between one-shot training and incremental training.
 /// Furthermore, the optimizer doesn't have to "own" the model or know anything about its hyperparameters,
 /// because it never has to initialize it.
-///
-/// The output of the loss function is currently unconstrained: should it be an associated
-/// type of the `Optimizer` trait itself? Should we add it as a generic parameter of the
-/// `train` method, with a set of reasonable trait bounds?
 pub trait Optimizer<M>
 where
     M: Model,
@@ -28,11 +24,7 @@ where
         inputs: &M::Input,
         targets: &M::Output,
         model: M,
-        loss: L,
     ) -> Result<M, Self::Error>
-    // Returning f64 is arbitrary, but I didn't want to flesh out a Loss trait yet
-    where
-        L: FnMut(&M::Output, &M::Output) -> f64;
 }

From 2b0bbf6db6ce187e0948f0c8774ee6137b360d98 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sat, 11 May 2019 18:02:48 +0100
Subject: [PATCH 11/35] Typo

---
 src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lib.rs b/src/lib.rs
index 668bb6c..47bcc59 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -24,7 +24,7 @@ where
         inputs: &M::Input,
         targets: &M::Output,
         model: M,
-    ) -> Result<M, Self::Error>
+    ) -> Result<M, Self::Error>;
 }

From d56a5a009e13cb057d26a1e373ea797d9daa349f Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sat, 11 May 2019 18:14:28 +0100
Subject: [PATCH 12/35] Typos

---
 src/lib.rs | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 47bcc59..f5673bb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,10 +1,9 @@
 use std::error;
 
-/// Where information is distilled from data.
+/// One step closer to the peak.
 ///
 /// `Optimizer` is generic over a type `M` implementing the `Model` trait: `M` is used to
-/// constrain what type of inputs and targets are acceptable, as well as what signature the
-/// loss function should have.
+/// constrain what type of inputs and targets are acceptable.
 ///
 /// `train` takes an instance of `M` as one of its inputs, `model`: it doesn't matter if `model`
 /// has been through several rounds of training before, or if it just came out of a `Blueprint`
@@ -19,7 +18,7 @@ where
 {
     type Error: error::Error;
 
-    fn train<L>(
+    fn train(
         &self,
         inputs: &M::Input,
         targets: &M::Output,
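With the loss parameter gone, an `Optimizer` implementation reduces to "consume a model, return a new one". A hypothetical sketch against the trait after [PATCH 12/35], reusing `MeanModel` and `EmptyBatch` from the earlier sketches; `StepOptimizer` and its `step` size are assumed names:

```rust
/// Hypothetical optimizer: nudge the stored mean toward the mean of the new
/// targets by a fixed step size.
pub struct StepOptimizer {
    pub step: f64,
}

impl Optimizer<MeanModel> for StepOptimizer {
    type Error = EmptyBatch;

    fn train(
        &self,
        _inputs: &Vec<f64>,
        targets: &Vec<f64>,
        model: MeanModel,
    ) -> Result<MeanModel, Self::Error> {
        if targets.is_empty() {
            return Err(EmptyBatch);
        }
        let batch_mean = targets.iter().sum::<f64>() / targets.len() as f64;
        // The old model is consumed and a new one returned: feeding the result
        // back into `train` on the next batch is incremental training.
        Ok(MeanModel {
            mean: model.mean + self.step * (batch_mean - model.mean),
        })
    }
}
```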
From fbe9a6659598a02f77464f1be2bf19320c4a6de2 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 12 May 2019 17:40:39 +0100
Subject: [PATCH 13/35] Add BlueprintGenerator

---
 src/lib.rs | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/src/lib.rs b/src/lib.rs
index f5673bb..ad75b00 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -46,6 +46,28 @@ where
     fn initialize(&self, inputs: &M::Input, targets: &M::Output) -> Result<M, Self::Error>;
 }
 
+
+/// Where you need to go meta (hyperparameters!).
+///
+/// `BlueprintGenerator`s can be used to explore different combinations of hyperparameters
+/// when you are working with a certain `Model` type.
+///
+/// `BlueprintGenerator::generate` takes as input a closure that returns a `Blueprint` instance
+/// and an iterator that yields a set of possible inputs for this closure. It returns,
+/// if successful, an `IntoIterator` type yielding instances of blueprints.
+pub trait BlueprintGenerator<B, M>
+where
+    B: Blueprint<M>,
+    M: Model
+{
+    type Error: error::Error;
+
+    fn generate<F, P, I>(&self, parametrization: F, params: &Iterator<Item = P>) -> Result<I, Self::Error>
+    where
+        F: FnMut(P) -> B,
+        I: IntoIterator<Item = B>;
+}
+
 /// The basic `Model` trait.

From 6f1b906a4216027292e0d7dd77f095b834e63fbd Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 12 May 2019 18:20:49 +0100
Subject: [PATCH 14/35] Remove parameters from generate

---
 src/lib.rs | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index ad75b00..ff62a2f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -52,9 +52,8 @@ where
 /// `BlueprintGenerator`s can be used to explore different combinations of hyperparameters
 /// when you are working with a certain `Model` type.
 ///
-/// `BlueprintGenerator::generate` takes as input a closure that returns a `Blueprint` instance
-/// and an iterator that yields a set of possible inputs for this closure. It returns,
-/// if successful, an `IntoIterator` type yielding instances of blueprints.
+/// `BlueprintGenerator::generate` returns, if successful, an `IntoIterator` type
+/// yielding instances of blueprints.
 pub trait BlueprintGenerator<B, M>
 where
     B: Blueprint<M>,
@@ -62,9 +61,8 @@ where
 {
     type Error: error::Error;
 
-    fn generate<F, P, I>(&self, parametrization: F, params: &Iterator<Item = P>) -> Result<I, Self::Error>
+    fn generate<I>(&self) -> Result<I, Self::Error>
     where
-        F: FnMut(P) -> B,
         I: IntoIterator<Item = B>;
 }

From f1a0312ef59af93ccace06ad2fb010ae0fa9f63d Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 12 May 2019 18:32:06 +0100
Subject: [PATCH 15/35] Refine BlueprintGenerator, moving I to associated
 type. Implement BlueprintGenerator for Blueprints

---
 src/lib.rs | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index ff62a2f..d413678 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,5 @@
 use std::error;
+use std::iter;
 
 /// One step closer to the peak.
 ///
@@ -46,6 +47,22 @@ where
     fn initialize(&self, inputs: &M::Input, targets: &M::Output) -> Result<M, Self::Error>;
 }
 
+/// Any `Blueprint` can be used as `BlueprintGenerator`, as long as it's clonable:
+/// it returns an iterator with a single element, a clone of itself.
+impl<B, M> BlueprintGenerator<B, M> for B
+where
+    B: Blueprint<M> + Clone,
+    M: Model,
+{
+    type Error = B::Error;
+    type Output = iter::Once<B>;
+
+    fn generate(&self) -> Result<Self::Output, Self::Error>
+    {
+        Ok(iter::once(self.clone()))
+    }
+}
+
 /// Where you need to go meta (hyperparameters!).
 ///
 /// `BlueprintGenerator`s can be used to explore different combinations of hyperparameters
@@ -60,10 +77,9 @@ where
     M: Model
 {
     type Error: error::Error;
+    type Output: IntoIterator<Item = B>;
 
-    fn generate<I>(&self) -> Result<I, Self::Error>
-    where
-        I: IntoIterator<Item = B>;
+    fn generate(&self) -> Result<Self::Output, Self::Error>;
 }
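A hypothetical grid-style generator against the trait as it stands after [PATCH 15/35], reusing the `MeanBlueprint` sketch from earlier; `MeanGrid` and `candidates` are assumed names:

```rust
/// Hypothetical grid: one blueprint per candidate fallback value.
pub struct MeanGrid {
    pub candidates: Vec<f64>,
}

impl BlueprintGenerator<MeanBlueprint, MeanModel> for MeanGrid {
    type Error = EmptyBatch;
    type Output = Vec<MeanBlueprint>;

    fn generate(&self) -> Result<Self::Output, Self::Error> {
        // One blueprint per hyperparameter value. A single clonable
        // `MeanBlueprint` already acts as a one-element generator for free,
        // thanks to the blanket impl in the patch above.
        Ok(self
            .candidates
            .iter()
            .map(|&fallback| MeanBlueprint { fallback })
            .collect())
    }
}
```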
From d39a6e3c624393eb5d043327762cc0e4ae1b97bf Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 12 May 2019 18:34:27 +0100
Subject: [PATCH 16/35] Re-org

---
 src/lib.rs | 67 +++++++++++++++++++++++++++---------------------------
 1 file changed, 33 insertions(+), 34 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index d413678..310665e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,26 @@
 use std::error;
 use std::iter;
 
+/// The basic `Model` trait.
+///
+/// It is training-agnostic: a model takes an input and returns an output.
+///
+/// There might be multiple ways to discover the best settings for every
+/// particular algorithm (e.g. training a logistic regressor using
+/// a pseudo-inverse matrix vs using gradient descent).
+/// It doesn't matter: the end result, the model, is a set of parameters.
+/// The way those parameters originated is an orthogonal concept.
+///
+/// In the same way, it has no notion of loss or "correct" predictions.
+/// Those concepts are embedded elsewhere.
+pub trait Model {
+    type Input;
+    type Output;
+    type Error: error::Error;
+
+    fn predict(&self, inputs: &Self::Input) -> Result<Self::Output, Self::Error>;
+}
+
 /// One step closer to the peak.
 ///
 /// `Optimizer` is generic over a type `M` implementing the `Model` trait: `M` is used to
@@ -47,23 +67,6 @@ where
     fn initialize(&self, inputs: &M::Input, targets: &M::Output) -> Result<M, Self::Error>;
 }
 
-/// Any `Blueprint` can be used as `BlueprintGenerator`, as long as it's clonable:
-/// it returns an iterator with a single element, a clone of itself.
-impl<B, M> BlueprintGenerator<B, M> for B
-where
-    B: Blueprint<M> + Clone,
-    M: Model,
-{
-    type Error = B::Error;
-    type Output = iter::Once<B>;
-
-    fn generate(&self) -> Result<Self::Output, Self::Error>
-    {
-        Ok(iter::once(self.clone()))
-    }
-}
-
-
 /// Where you need to go meta (hyperparameters!).
 ///
 /// `BlueprintGenerator`s can be used to explore different combinations of hyperparameters
@@ -82,22 +85,18 @@ where
     fn generate(&self) -> Result<Self::Output, Self::Error>;
 }
 
-/// The basic `Model` trait.
-///
-/// It is training-agnostic: a model takes an input and returns an output.
-///
-/// There might be multiple ways to discover the best settings for every
-/// particular algorithm (e.g. training a logistic regressor using
-/// a pseudo-inverse matrix vs using gradient descent).
-/// It doesn't matter: the end result, the model, is a set of parameters.
-/// The way those parameters originated is an orthogonal concept.
-///
-/// In the same way, it has no notion of loss or "correct" predictions.
-/// Those concepts are embedded elsewhere.
-pub trait Model {
-    type Input;
-    type Output;
-    type Error: error::Error;
+/// Any `Blueprint` can be used as `BlueprintGenerator`, as long as it's clonable:
+/// it returns an iterator with a single element, a clone of itself.
+impl<B, M> BlueprintGenerator<B, M> for B
+    where
+        B: Blueprint<M> + Clone,
+        M: Model,
+{
+    type Error = B::Error;
+    type Output = iter::Once<B>;
 
-    fn predict(&self, inputs: &Self::Input) -> Result<Self::Output, Self::Error>;
+    fn generate(&self) -> Result<Self::Output, Self::Error>
+    {
+        Ok(iter::once(self.clone()))
+    }
 }
From cbfa74cbee3d7737261887a4cd1401402d95f635 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 12 May 2019 18:37:14 +0100
Subject: [PATCH 17/35] Add blanket implementation of Blueprint for Model
 types

---
 src/lib.rs | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/lib.rs b/src/lib.rs
index 310665e..4b0b1c1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -67,6 +67,20 @@ where
     fn initialize(&self, inputs: &M::Input, targets: &M::Output) -> Result<M, Self::Error>;
 }
 
+/// Any `Model` can be used as `Blueprint`, as long as it's clonable:
+/// it returns a clone of itself when `initialize` is called, ignoring the data.
+impl<M> Blueprint<M> for M
+where
+    M: Model + Clone,
+{
+    type Error = M::Error;
+
+    fn initialize(&self, _inputs: &M::Input, _targets: &M::Output) -> Result<M, Self::Error>
+    {
+        Ok(self.clone())
+    }
+}
+
 /// Where you need to go meta (hyperparameters!).
 ///
 /// `BlueprintGenerator`s can be used to explore different combinations of hyperparameters
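What the blanket impl buys in practice, continuing the hypothetical `MeanModel` sketch (assuming it additionally derives `Clone`): a pre-trained model can stand in wherever a blueprint is expected, giving warm starts for free.

```rust
fn warm_start(inputs: &Vec<f64>, targets: &Vec<f64>) -> Result<MeanModel, EmptyBatch> {
    // A clonable model is its own blueprint: `initialize` ignores the data
    // and hands back a copy of the already-trained parameters.
    let pretrained = MeanModel { mean: 3.5 };
    pretrained.initialize(inputs, targets)
}
```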
From afebfe1c55fa0413d596a4504f2c5670ac013d01 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Wed, 15 May 2019 08:42:53 +0100
Subject: [PATCH 18/35] Refactor

---
 src/lib.rs | 94 +++++++++++++++++++++++++-----------------------------
 1 file changed, 43 insertions(+), 51 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 4b0b1c1..9358bf7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,84 +1,78 @@
 use std::error;
 use std::iter;
 
-/// The basic `Model` trait.
+/// The basic `Transformer` trait.
 ///
-/// It is training-agnostic: a model takes an input and returns an output.
+/// It is training-agnostic: a transformer takes an input and returns an output.
 ///
 /// There might be multiple ways to discover the best settings for every
 /// particular algorithm (e.g. training a logistic regressor using
 /// a pseudo-inverse matrix vs using gradient descent).
-/// It doesn't matter: the end result, the model, is a set of parameters.
+/// It doesn't matter: the end result, the transformer, is a set of parameters.
 /// The way those parameters originated is an orthogonal concept.
 ///
 /// In the same way, it has no notion of loss or "correct" predictions.
 /// Those concepts are embedded elsewhere.
-pub trait Model {
+pub trait Transformer {
     type Input;
     type Output;
     type Error: error::Error;
 
-    fn predict(&self, inputs: &Self::Input) -> Result<Self::Output, Self::Error>;
+    fn transform(&self, inputs: &Self::Input) -> Result<Self::Output, Self::Error>;
 }
 
 /// One step closer to the peak.
 ///
-/// `Optimizer` is generic over a type `M` implementing the `Model` trait: `M` is used to
+/// `Fit` is generic over a type `B` implementing the `Blueprint` trait: `B::Transformer` is used to
 /// constrain what type of inputs and targets are acceptable.
 ///
-/// `train` takes an instance of `M` as one of its inputs, `model`: it doesn't matter if `model`
-/// has been through several rounds of training before, or if it just came out of a `Blueprint`
-/// using `initialize` - it's consumed by `train` and a new model is returned.
+/// `fit` takes an instance of `B` as one of its inputs, `blueprint`: it's consumed with move
+/// semantics and a new transformer is returned.
 ///
-/// This means that there is no difference between one-shot training and incremental training.
-/// Furthermore, the optimizer doesn't have to "own" the model or know anything about its hyperparameters,
-/// because it never has to initialize it.
-pub trait Optimizer<M>
+/// It's a transition in the transformer state machine: from `Blueprint` to `Transformer`.
+pub trait Fit<B>
 where
-    M: Model,
+    B: Blueprint,
 {
     type Error: error::Error;
 
-    fn train(
+    fn fit(
         &self,
-        inputs: &M::Input,
-        targets: &M::Output,
-        model: M,
-    ) -> Result<M, Self::Error>;
+        inputs: &B::Transformer::Input,
+        targets: &B::Transformer::Output,
+        blueprint: B,
+    ) -> Result<B::Transformer, Self::Error>;
 }
 
-/// Where `Model`s are forged.
-///
-/// `Blueprint`s are used to specify how to build and initialize an instance of the model type `M`.
-///
-/// For the same model type `M`, nothing prevents a user from providing more than one `Blueprint`:
-/// multiple initialization strategies can sometimes be used to build the same model type.
-///
-/// Each of these strategies can take different (hyper)parameters, even though they return an
-/// instance of the same model type in the end.
-///
-/// The initialization procedure could be data-dependent, hence the signature of `initialize`.
-pub trait Blueprint<M>
+pub trait IncrementalFit<T>
 where
-    M: Model,
+    T: Transformer
 {
     type Error: error::Error;
 
-    fn initialize(&self, inputs: &M::Input, targets: &M::Output) -> Result<M, Self::Error>;
-}
+    fn incremental_fit(
+        &self,
+        inputs: &T::Input,
+        targets: &T::Output,
+        transformer: T,
+    ) -> Result<T, Self::Error>;
 
-/// Any `Model` can be used as `Blueprint`, as long as it's clonable:
-/// it returns a clone of itself when `initialize` is called, ignoring the data.
-impl<M> Blueprint<M> for M
-where
-    M: Model + Clone,
-{
-    type Error = M::Error;
+}
 
-    fn initialize(&self, _inputs: &M::Input, _targets: &M::Output) -> Result<M, Self::Error>
-    {
-        Ok(self.clone())
-    }
-}
+/// Where `Transformer`s are forged.
+///
+/// `Blueprint` is a marker trait: it identifies what types can be used as starting points for
+/// building `Transformer`s. It's the initial stage of the transformer state machine.
+///
+/// Every `Blueprint` is associated with a single `Transformer` type (is it wise to do so?).
+///
+/// For the same transformer type `T`, nothing prevents a user from providing more than one `Blueprint`:
+/// multiple initialization strategies can sometimes be used to build the same model type.
+///
+/// Each of these strategies can take different (hyper)parameters, even though they return an
+/// instance of the same model type in the end.
+pub trait Blueprint {
+    type Transformer: Transformer;
+}
 
 /// Where you need to go meta (hyperparameters!).
 ///
 /// `BlueprintGenerator`s can be used to explore different combinations of hyperparameters
 /// when you are working with a certain `Model` type.
 ///
 /// `BlueprintGenerator::generate` returns, if successful, an `IntoIterator` type
 /// yielding instances of blueprints.
-pub trait BlueprintGenerator<B, M>
+pub trait BlueprintGenerator<B>
 where
-    B: Blueprint<M>,
-    M: Model
+    B: Blueprint,
 {
     type Error: error::Error;
     type Output: IntoIterator<Item = B>;
 
     fn generate(&self) -> Result<Self::Output, Self::Error>;
 }
 
 /// Any `Blueprint` can be used as `BlueprintGenerator`, as long as it's clonable:
 /// it returns an iterator with a single element, a clone of itself.
-impl<B, M> BlueprintGenerator<B, M> for B
-    where
-        B: Blueprint<M> + Clone,
-        M: Model,
+impl<B> BlueprintGenerator<B> for B
+where
+    B: Blueprint + Clone,
 {
     type Error = B::Error;
     type Output = iter::Once<B>;
 
-    fn generate(&self) -> Result<Self::Output, Self::Error>
-    {
+    fn generate(&self) -> Result<Self::Output, Self::Error> {
         Ok(iter::once(self.clone()))
     }
 }
From 7a2c4a9b1e964ab33ed494d9aa8a18043deb0b24 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Wed, 15 May 2019 08:45:09 +0100
Subject: [PATCH 19/35] Add docs, fix typos

---
 src/lib.rs | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 9358bf7..ed53f15 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -44,6 +44,15 @@ where
     ) -> Result<B::Transformer, Self::Error>;
 }
 
+/// We are not done with that `Transformer` yet.
+///
+/// `IncrementalFit` is generic over a type `T` implementing the `Transformer` trait: `T` is used to
+/// constrain what type of inputs and targets are acceptable.
+///
+/// `incremental_fit` takes an instance of `T` as one of its inputs, `transformer`: it's consumed with move
+/// semantics and a new transformer is returned.
+///
+/// It's a transition in the transformer state machine: from `Transformer` to `Transformer`.
 pub trait IncrementalFit<T>
 where
     T: Transformer
@@ -67,10 +76,10 @@ where
 /// Every `Blueprint` is associated with a single `Transformer` type (is it wise to do so?).
 ///
 /// For the same transformer type `T`, nothing prevents a user from providing more than one `Blueprint`:
-/// multiple initialization strategies can sometimes be used to build the same model type.
+/// multiple initialization strategies can sometimes be used to build the same transformer type.
 ///
 /// Each of these strategies can take different (hyper)parameters, even though they return an
-/// instance of the same model type in the end.
+/// instance of the same transformer type in the end.
 pub trait Blueprint {
     type Transformer: Transformer;
 }
@@ -87,7 +96,7 @@ pub trait Blueprint {
 /// Where you need to go meta (hyperparameters!).
 ///
 /// `BlueprintGenerator`s can be used to explore different combinations of hyperparameters
-/// when you are working with a certain `Model` type.
+/// when you are working with a certain `Transformer` type.
 ///
 /// `BlueprintGenerator::generate` returns, if successful, an `IntoIterator` type
 /// yielding instances of blueprints.
From f7f8f50dda5c2114bcea80962e9b3d361e7f2c2b Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Wed, 15 May 2019 08:53:19 +0100
Subject: [PATCH 20/35] Make input and output generic parameters

---
 src/lib.rs | 37 ++++++++++++++++++-------------------
 1 file changed, 18 insertions(+), 19 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index ed53f15..38c7d11 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,12 +13,10 @@ use std::iter;
 ///
 /// In the same way, it has no notion of loss or "correct" predictions.
 /// Those concepts are embedded elsewhere.
-pub trait Transformer {
-    type Input;
-    type Output;
+pub trait Transformer<I, O> {
     type Error: error::Error;
 
-    fn transform(&self, inputs: &Self::Input) -> Result<Self::Output, Self::Error>;
+    fn transform(&self, inputs: &I) -> Result<O, Self::Error>;
 }
 
 /// One step closer to the peak.
@@ -30,16 +28,16 @@ pub trait Transformer {
 /// semantics and a new transformer is returned.
 ///
 /// It's a transition in the transformer state machine: from `Blueprint` to `Transformer`.
-pub trait Fit<B>
+pub trait Fit<B, I, O>
 where
-    B: Blueprint,
+    B: Blueprint<I, O>,
 {
     type Error: error::Error;
 
     fn fit(
         &self,
-        inputs: &B::Transformer::Input,
-        targets: &B::Transformer::Output,
+        inputs: &I,
+        targets: &O,
         blueprint: B,
     ) -> Result<B::Transformer, Self::Error>;
 }
@@ -53,16 +51,16 @@ where
 /// semantics and a new transformer is returned.
 ///
 /// It's a transition in the transformer state machine: from `Transformer` to `Transformer`.
-pub trait IncrementalFit<T>
+pub trait IncrementalFit<T, I, O>
 where
-    T: Transformer
+    T: Transformer<I, O>
 {
     type Error: error::Error;
 
     fn incremental_fit(
         &self,
-        inputs: &T::Input,
-        targets: &T::Output,
+        inputs: &I,
+        targets: &O,
         transformer: T,
     ) -> Result<T, Self::Error>;
 
@@ -80,8 +78,8 @@ where
 ///
 /// Each of these strategies can take different (hyper)parameters, even though they return an
 /// instance of the same transformer type in the end.
-pub trait Blueprint {
-    type Transformer: Transformer;
+pub trait Blueprint<I, O> {
+    type Transformer: Transformer<I, O>;
 }
 
 /// Where you need to go meta (hyperparameters!).
@@ -91,9 +89,9 @@ where
 ///
 /// `BlueprintGenerator::generate` returns, if successful, an `IntoIterator` type
 /// yielding instances of blueprints.
-pub trait BlueprintGenerator<B>
+pub trait BlueprintGenerator<B, I, O>
 where
-    B: Blueprint,
+    B: Blueprint<I, O>,
 {
     type Error: error::Error;
     type Output: IntoIterator<Item = B>;
@@ -103,11 +101,12 @@ where
 
 /// Any `Blueprint` can be used as `BlueprintGenerator`, as long as it's clonable:
 /// it returns an iterator with a single element, a clone of itself.
-impl<B> BlueprintGenerator<B> for B
+impl<B, I, O> BlueprintGenerator<B, I, O> for B
 where
-    B: Blueprint + Clone,
+    B: Blueprint<I, O> + Clone,
 {
-    type Error = B::Error;
+    // Random error, didn't have time to get a proper one
+    type Error = std::io::Error;
     type Output = iter::Once<B>;
 
     fn generate(&self) -> Result<Self::Output, Self::Error> {

From da9a9d73f76686a27ea6d88ed56d11c97412c37e Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Wed, 15 May 2019 08:56:11 +0100
Subject: [PATCH 21/35] Add comments

---
 src/lib.rs | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/src/lib.rs b/src/lib.rs
index 38c7d11..6c45f3e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,6 +13,15 @@ use std::iter;
 ///
 /// In the same way, it has no notion of loss or "correct" predictions.
 /// Those concepts are embedded elsewhere.
+///
+/// It's generic over input and output types:
+/// - you can transform a fully in-memory dataset;
+/// - you can transform a stream of data;
+/// - you can return a class;
+/// - you can return a probability distribution.
+///
+/// The mechanism for selecting the desired output, when not self-evident from the downstream
+/// usage, should be the same as for the `::collect()` method.
 pub trait Transformer<I, O> {
     type Error: error::Error;
 
@@ -28,6 +37,12 @@ pub trait Transformer<I, O> {
 /// semantics and a new transformer is returned.
 ///
 /// It's a transition in the transformer state machine: from `Blueprint` to `Transformer`.
+///
+/// It's generic over input and output types:
+/// - you can fit on a fully in-memory dataset;
+/// - you can fit on a stream of data;
+/// - you can use integer-encoded class membership as a target;
+/// - you can use a one-hot-encoded class membership as a target.
 pub trait Fit<B, I, O>
 where
     B: Blueprint<I, O>,
@@ -51,6 +66,12 @@ where
 /// semantics and a new transformer is returned.
 ///
 /// It's a transition in the transformer state machine: from `Transformer` to `Transformer`.
+///
+/// It's generic over input and output types:
+/// - you can fit on a fully in-memory dataset;
+/// - you can fit on a stream of data;
+/// - you can use integer-encoded class membership as a target;
+/// - you can use a one-hot-encoded class membership as a target.
 pub trait IncrementalFit<T, I, O>
 where
     T: Transformer<I, O>

From 834ffddd37cf229db3add53bafede449dbdfa570 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 15:45:56 +0100
Subject: [PATCH 22/35] Doc minor fix

---
 src/lib.rs | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/lib.rs b/src/lib.rs
index 6c45f3e..5fd8b4f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -36,6 +36,8 @@ pub trait Transformer<I, O> {
 /// `fit` takes an instance of `B` as one of its inputs, `blueprint`: it's consumed with move
 /// semantics and a new transformer is returned.
 ///
+/// Different types implementing `Fit` can work on the same `Blueprint` type!
+///
 /// It's a transition in the transformer state machine: from `Blueprint` to `Transformer`.
 ///
 /// It's generic over input and output types:
@@ -84,7 +86,6 @@ where
         targets: &O,
         transformer: T,
     ) -> Result<T, Self::Error>;
-
 }
 
 /// Where `Transformer`s are forged.
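With inputs and outputs as generic parameters the traits now compile end to end, so the whole state machine can be exercised with dummy types. A self-contained sketch (all names here are hypothetical, chosen for illustration):

```rust
use std::{error, fmt};

#[derive(Debug)]
pub struct NeverFails;

impl fmt::Display for NeverFails {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "this operation cannot fail")
    }
}

impl error::Error for NeverFails {}

/// Transformer: adds a learned offset to every sample.
pub struct Shift {
    pub offset: f64,
}

impl Transformer<Vec<f64>, Vec<f64>> for Shift {
    type Error = NeverFails;

    fn transform(&self, inputs: &Vec<f64>) -> Result<Vec<f64>, Self::Error> {
        Ok(inputs.iter().map(|x| x + self.offset).collect())
    }
}

/// Blueprint: nothing to configure, but it still names its transformer type.
#[derive(Clone)]
pub struct ShiftConfig;

impl Blueprint<Vec<f64>, Vec<f64>> for ShiftConfig {
    type Transformer = Shift;
}

/// Optimizer with a closed-form fit: the offset is the mean target-input gap.
/// (Nothing stops a second optimizer type, e.g. a gradient-based one, from
/// implementing `Fit` for the very same `ShiftConfig`.)
pub struct MeanGap;

impl Fit<ShiftConfig, Vec<f64>, Vec<f64>> for MeanGap {
    type Error = NeverFails;

    fn fit(
        &self,
        inputs: &Vec<f64>,
        targets: &Vec<f64>,
        _blueprint: ShiftConfig,
    ) -> Result<Shift, Self::Error> {
        // Assumes a non-empty batch; a real implementation would check.
        let gap: f64 = targets.iter().zip(inputs).map(|(t, x)| t - x).sum();
        Ok(Shift {
            offset: gap / inputs.len() as f64,
        })
    }
}

fn demo() -> Result<(), NeverFails> {
    // Blueprint -> Transformer, then Transformer -> output.
    let shift = MeanGap.fit(&vec![1.0, 2.0], &vec![3.0, 4.0], ShiftConfig)?;
    assert_eq!(shift.transform(&vec![0.0])?, vec![2.0]);
    Ok(())
}
```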
From 9ffa8cdbff7cc3cfe47e2c0bdcd95fd21b0c655d Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 15:48:42 +0100
Subject: [PATCH 23/35] Add examples folder

---
 examples/basic.rs | 3 +++
 1 file changed, 3 insertions(+)
 create mode 100644 examples/basic.rs

diff --git a/examples/basic.rs b/examples/basic.rs
new file mode 100644
index 0000000..646fa5e
--- /dev/null
+++ b/examples/basic.rs
@@ -0,0 +1,3 @@
+fn main() {
+    println!("Hello world!");
+}
\ No newline at end of file

From dfe62f432db5c0571cbae4f7994fc3b524700dc7 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 16:05:23 +0100
Subject: [PATCH 24/35] Basic transformer implementation for standard scaling

---
 Cargo.toml        |  4 ++++
 examples/basic.rs | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+)

diff --git a/Cargo.toml b/Cargo.toml
index 7410715..7768afe 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,3 +5,7 @@ authors = ["LukeMathWalker "]
 edition = "2018"
 
 [dependencies]
+
+[dev-dependencies]
+ndarray = "0.12.1"
+derive_more = "0.13.0"
diff --git a/examples/basic.rs b/examples/basic.rs
index 646fa5e..2771477 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -1,3 +1,36 @@
+extern crate linfa;
+extern crate ndarray;
+#[macro_use]
+extern crate derive_more;
+
+use std::error::Error;
+use linfa::Transformer;
+use ndarray::{ArrayBase, Ix1, Data, Array1};
+
+#[derive(Debug, Eq, PartialEq, From, Display)]
+pub struct ScalingError {}
+
+impl Error for ScalingError {}
+
+pub struct StandardScaler {
+    mean: f64,
+    standard_deviation: f64,
+}
+
+impl<S> Transformer<ArrayBase<S, Ix1>, Array1<f64>> for StandardScaler
+where
+    S: Data<Elem=f64>,
+{
+    type Error = ScalingError;
+
+    fn transform(&self, inputs: &ArrayBase<S, Ix1>) -> Result<Array1<f64>, Self::Error>
+    where
+        S: Data<Elem=f64>,
+    {
+        Ok((inputs - self.mean) / self.standard_deviation)
+    }
+}
+
 fn main() {
     println!("Hello world!");
 }
\ No newline at end of file
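A quick smoke test for the scaler as it stands after this patch, with the parameters set by hand (fitting them from data only arrives a few patches later). `smoke_test` is a hypothetical helper written against the types in `examples/basic.rs`:

```rust
use ndarray::arr1;

fn smoke_test() -> Result<(), ScalingError> {
    let scaler = StandardScaler {
        mean: 2.0,
        standard_deviation: 2.0,
    };
    // (x - 2.0) / 2.0 maps [0, 2, 4] onto [-1, 0, 1].
    let scaled = scaler.transform(&arr1(&[0.0, 2.0, 4.0]))?;
    assert_eq!(scaled, arr1(&[-1.0, 0.0, 1.0]));
    Ok(())
}
```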
From 93ad90af1f90ebbabcda8ded24fdfa4c5aa3b9d5 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 16:23:09 +0100
Subject: [PATCH 25/35] Add other structs

---
 examples/basic.rs | 38 +++++++++++++++++++++++++++++++++-----
 src/lib.rs        | 25 +++++++------------------
 2 files changed, 40 insertions(+), 23 deletions(-)

diff --git a/examples/basic.rs b/examples/basic.rs
index 2771477..94b7d17 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -3,31 +3,61 @@ extern crate ndarray;
 #[macro_use]
 extern crate derive_more;
 
-use std::error::Error;
-use linfa::Transformer;
-use ndarray::{ArrayBase, Ix1, Data, Array1};
+use linfa::{Transformer, Blueprint};
+use ndarray::{Array1, ArrayBase, Data, Ix1};
+use std::error::Error;
 
+/// Fast-and-dirty error struct
 #[derive(Debug, Eq, PartialEq, From, Display)]
 pub struct ScalingError {}
 
 impl Error for ScalingError {}
 
+/// Given an input, it rescales it to have zero mean and unit variance.
+///
+/// We use 64-bit floats for simplicity.
 pub struct StandardScaler {
+    // Delta degrees of freedom.
+    // With ddof = 1, you get the sample standard deviation
+    // With ddof = 0, you get the population standard deviation
+    ddof: u8,
     mean: f64,
     standard_deviation: f64,
 }
 
+/// It keeps track of the number of samples seen so far, to allow for
+/// incremental computation of mean and standard deviation.
+pub struct OnlineOptimizer {
+    n_samples: u64
+}
+
+pub struct Config {
+    // Delta degrees of freedom.
+    // With ddof = 1, you get the sample standard deviation
+    // With ddof = 0, you get the population standard deviation
+    ddof: u8
+}
+
+/// Defaults to computing the sample standard deviation.
+impl Default for Config {
+    fn default() -> Self {
+        Self { ddof: 1 }
+    }
+}
+
 impl<S> Transformer<ArrayBase<S, Ix1>, Array1<f64>> for StandardScaler
 where
-    S: Data<Elem=f64>,
+    S: Data<Elem = f64>,
 {
     type Error = ScalingError;
 
     fn transform(&self, inputs: &ArrayBase<S, Ix1>) -> Result<Array1<f64>, Self::Error>
     where
-        S: Data<Elem=f64>,
+        S: Data<Elem = f64>,
     {
         Ok((inputs - self.mean) / self.standard_deviation)
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 5fd8b4f..7c437bb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -51,12 +51,7 @@ where
 {
     type Error: error::Error;
 
-    fn fit(
-        &self,
-        inputs: &I,
-        targets: &O,
-        blueprint: B,
-    ) -> Result<B::Transformer, Self::Error>;
+    fn fit(&self, inputs: &I, targets: &O, blueprint: B) -> Result<B::Transformer, Self::Error>;
 }
 
 /// We are not done with that `Transformer` yet.
@@ -76,16 +71,11 @@ where
 pub trait IncrementalFit<T, I, O>
 where
-    T: Transformer<I, O>
+    T: Transformer<I, O>,
 {
     type Error: error::Error;
 
-    fn incremental_fit(
-        &self,
-        inputs: &I,
-        targets: &O,
-        transformer: T,
-    ) -> Result<T, Self::Error>;
+    fn incremental_fit(&self, inputs: &I, targets: &O, transformer: T) -> Result<T, Self::Error>;
 }
 
 /// Where `Transformer`s are forged.

From 6aa09e09ae6ddb9227342be90373168f1de2a211 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 16:38:32 +0100
Subject: [PATCH 26/35] Skeleton of Fit and IncrementalFit implementation

---
 examples/basic.rs | 53 ++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 48 insertions(+), 5 deletions(-)

diff --git a/examples/basic.rs b/examples/basic.rs
index 94b7d17..2311166 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -3,10 +3,14 @@ extern crate ndarray;
 #[macro_use]
 extern crate derive_more;
 
-use linfa::{Transformer, Blueprint};
+use linfa::{Blueprint, Fit, IncrementalFit, Transformer};
 use ndarray::{Array1, ArrayBase, Data, Ix1};
 use std::error::Error;
 
+/// Short-hand notations
+type Input<S: Data<Elem = f64>> = ArrayBase<S, Ix1>;
+type Output = Array1<f64>;
+
 /// Fast-and-dirty error struct
 #[derive(Debug, Eq, PartialEq, From, Display)]
 pub struct ScalingError {}
@@ -28,14 +32,46 @@ pub struct StandardScaler {
 /// It keeps track of the number of samples seen so far, to allow for
 /// incremental computation of mean and standard deviation.
 pub struct OnlineOptimizer {
-    n_samples: u64
+    n_samples: u64,
+}
+
+impl<S> Fit<Config, Input<S>, Output> for OnlineOptimizer
+where
+    S: Data<Elem = f64>,
+{
+    type Error = ScalingError;
+
+    fn fit(
+        &self,
+        inputs: &Input<S>,
+        targets: &Output,
+        blueprint: Config,
+    ) -> Result<StandardScaler, Self::Error> {
+        unimplemented!()
+    }
+}
+
+impl<S> IncrementalFit<StandardScaler, Input<S>, Output> for OnlineOptimizer
+where
+    S: Data<Elem = f64>,
+{
+    type Error = ScalingError;
+
+    fn incremental_fit(
+        &self,
+        inputs: &Input<S>,
+        targets: &Output,
+        transformer: StandardScaler,
+    ) -> Result<StandardScaler, Self::Error> {
+        unimplemented!()
+    }
 }
 
 pub struct Config {
     // Delta degrees of freedom.
     // With ddof = 1, you get the sample standard deviation
     // With ddof = 0, you get the population standard deviation
-    ddof: u8
+    ddof: u8,
 }
 
 /// Defaults to computing the sample standard deviation.
@@ -45,13 +81,20 @@ impl Default for Config {
     }
 }
 
+impl<S> Blueprint<Input<S>, Output> for Config
+where
+    S: Data<Elem = f64>,
+{
+    type Transformer = StandardScaler;
+}
+
-impl<S> Transformer<ArrayBase<S, Ix1>, Array1<f64>> for StandardScaler
+impl<S> Transformer<Input<S>, Output> for StandardScaler
 where
     S: Data<Elem = f64>,
 {
     type Error = ScalingError;
 
-    fn transform(&self, inputs: &ArrayBase<S, Ix1>) -> Result<Array1<f64>, Self::Error>
+    fn transform(&self, inputs: &Input<S>) -> Result<Output, Self::Error>
     where
         S: Data<Elem = f64>,
     {
From ace68b11c45557fdcb45f03ca4d39e1af938498d Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 16:49:34 +0100
Subject: [PATCH 27/35] Implemented Fit trait

---
 examples/basic.rs | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/examples/basic.rs b/examples/basic.rs
index 2311166..b6d0914 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -4,7 +4,7 @@ extern crate ndarray;
 extern crate derive_more;
 
 use linfa::{Blueprint, Fit, IncrementalFit, Transformer};
-use ndarray::{Array1, ArrayBase, Data, Ix1};
+use ndarray::{Array1, ArrayBase, Data, Ix1, Axis};
 use std::error::Error;
 
 /// Short-hand notations
@@ -24,9 +24,9 @@ pub struct StandardScaler {
     // Delta degrees of freedom.
     // With ddof = 1, you get the sample standard deviation
     // With ddof = 0, you get the population standard deviation
-    ddof: u8,
-    mean: f64,
-    standard_deviation: f64,
+    pub ddof: u8,
+    pub mean: f64,
+    pub standard_deviation: f64,
 }
 
 /// It keeps track of the number of samples seen so far, to allow for
@@ -44,10 +44,19 @@ where
     fn fit(
         &self,
         inputs: &Input<S>,
-        targets: &Output,
+        _targets: &Output,
         blueprint: Config,
     ) -> Result<StandardScaler, Self::Error> {
-        unimplemented!()
+        if inputs.len() == 0 {
+            return Err(ScalingError {})
+        }
+        let mean = inputs.mean_axis(Axis(0)).into_scalar();
+        let standard_deviation = inputs.std_axis(Axis(0), blueprint.ddof as f64).into_scalar();
+        Ok(StandardScaler {
+            ddof: blueprint.ddof,
+            mean,
+            standard_deviation
+        })
     }
 }
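With `fit` implemented, the first real training call becomes possible. A hypothetical smoke test against the types in `examples/basic.rs` as of this patch (the scaler ignores its targets, so any array of matching length will do):

```rust
use ndarray::arr1;

fn fit_smoke_test() -> Result<(), ScalingError> {
    let optimizer = OnlineOptimizer { n_samples: 0 };
    let x = arr1(&[1.0, 2.0, 3.0, 4.0]);
    let y = arr1(&[0.0, 0.0, 0.0, 0.0]);
    let scaler = optimizer.fit(&x, &y, Config::default())?;
    // The fitted mean should be the batch mean, 2.5.
    assert!((scaler.mean - 2.5).abs() < 1e-12);
    Ok(())
}
```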
From c494d3c2a95c5c9f0bb94625d632d80994f6684d Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 16:54:21 +0100
Subject: [PATCH 28/35] Convert ddof to f64. Make fit and incremental_fit take
 self as mutable reference

---
 examples/basic.rs | 22 +++++++++++++---------
 src/lib.rs        | 10 ++++++++--
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/examples/basic.rs b/examples/basic.rs
index b6d0914..0787fde 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -4,7 +4,7 @@ extern crate ndarray;
 extern crate derive_more;
 
 use linfa::{Blueprint, Fit, IncrementalFit, Transformer};
-use ndarray::{Array1, ArrayBase, Data, Ix1, Axis};
+use ndarray::{Array1, ArrayBase, Axis, Data, Ix1};
 use std::error::Error;
 
 /// Short-hand notations
@@ -24,7 +24,7 @@ pub struct StandardScaler {
     // Delta degrees of freedom.
     // With ddof = 1, you get the sample standard deviation
     // With ddof = 0, you get the population standard deviation
-    pub ddof: u8,
+    pub ddof: f64,
     pub mean: f64,
     pub standard_deviation: f64,
 }
@@ -42,20 +42,24 @@ where
     type Error = ScalingError;
 
     fn fit(
-        &self,
+        &mut self,
         inputs: &Input<S>,
         _targets: &Output,
         blueprint: Config,
     ) -> Result<StandardScaler, Self::Error> {
         if inputs.len() == 0 {
-            return Err(ScalingError {})
+            return Err(ScalingError {});
         }
+        // Compute relevant quantities
         let mean = inputs.mean_axis(Axis(0)).into_scalar();
-        let standard_deviation = inputs.std_axis(Axis(0), blueprint.ddof as f64).into_scalar();
+        let standard_deviation = inputs.std_axis(Axis(0), blueprint.ddof).into_scalar();
+        // Initialize n_samples using the array length
+        self.n_samples = inputs.len() as u64;
+        // Return new, tuned scaler
         Ok(StandardScaler {
             ddof: blueprint.ddof,
             mean,
-            standard_deviation
+            standard_deviation,
         })
     }
 }
@@ -68,7 +72,7 @@ where
     type Error = ScalingError;
 
     fn incremental_fit(
-        &self,
+        &mut self,
         inputs: &Input<S>,
         targets: &Output,
         transformer: StandardScaler,
@@ -84,13 +88,13 @@ pub struct Config {
     // Delta degrees of freedom.
     // With ddof = 1, you get the sample standard deviation
     // With ddof = 0, you get the population standard deviation
-    ddof: u8,
+    ddof: f64,
 }
 
 /// Defaults to computing the sample standard deviation.
 impl Default for Config {
     fn default() -> Self {
-        Self { ddof: 1 }
+        Self { ddof: 1. }
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 7c437bb..c9435fc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -51,7 +51,8 @@ where
 {
     type Error: error::Error;
 
-    fn fit(&self, inputs: &I, targets: &O, blueprint: B) -> Result<B::Transformer, Self::Error>;
+    fn fit(&mut self, inputs: &I, targets: &O, blueprint: B)
+        -> Result<B::Transformer, Self::Error>;
 }
 
 /// We are not done with that `Transformer` yet.
@@ -76,7 +77,12 @@ where
 {
     type Error: error::Error;
 
-    fn incremental_fit(&self, inputs: &I, targets: &O, transformer: T) -> Result<T, Self::Error>;
+    fn incremental_fit(
+        &mut self,
+        inputs: &I,
+        targets: &O,
+        transformer: T,
+    ) -> Result<T, Self::Error>;
 }
 
 /// Where `Transformer`s are forged.
From d4735888b44d2eb86362cffdf5ac5c80fafda539 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 17:10:38 +0100
Subject: [PATCH 29/35] Implement IncrementalFit

---
 examples/basic.rs | 33 ++++++++++++++++++++++++++++++---
 1 file changed, 30 insertions(+), 3 deletions(-)

diff --git a/examples/basic.rs b/examples/basic.rs
index 0787fde..348a5ed 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -8,7 +8,7 @@ use ndarray::{Array1, ArrayBase, Axis, Data, Ix1};
 use std::error::Error;
 
 /// Short-hand notations
-type Input<S: Data<Elem = f64>> = ArrayBase<S, Ix1>;
+type Input<S> = ArrayBase<S, Ix1>;
 type Output = Array1<f64>;
 
 /// Fast-and-dirty error struct
@@ -73,10 +73,37 @@ where
     fn incremental_fit(
         &mut self,
         inputs: &Input<S>,
-        targets: &Output,
+        _targets: &Output,
         transformer: StandardScaler,
     ) -> Result<StandardScaler, Self::Error> {
-        unimplemented!()
+        if inputs.len() == 0 {
+            // Nothing to be done
+            return Ok(transformer);
+        }
+        // Compute relevant quantities for the new batch
+        let batch_n_samples = inputs.len();
+        let batch_mean = inputs.mean_axis(Axis(0)).into_scalar();
+        let batch_std = inputs.std_axis(Axis(0), transformer.ddof).into_scalar();
+
+        // Update
+        let mean_delta = batch_mean - transformer.mean;
+        let new_n_samples = self.n_samples + (batch_n_samples as u64);
+        let new_mean =
+            transformer.mean + mean_delta * (batch_n_samples as f64) / (new_n_samples as f64);
+        let new_std = transformer.standard_deviation
+            + batch_std
+            + mean_delta.powi(2) * (self.n_samples as f64) * (batch_n_samples as f64)
+                / (new_n_samples as f64);
+
+        // Update n_samples
+        self.n_samples = new_n_samples;
+
+        // Return tuned scaler
+        Ok(StandardScaler {
+            ddof: transformer.ddof,
+            mean: new_mean,
+            standard_deviation: new_std,
+        })
+    }
 }
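A side note on the update above: it adds the two standard deviations directly, which is a rough running estimate rather than the exact merge, since standard deviations do not combine additively. For comparison, a sketch of the exact pairwise merge (Chan et al.), tracking population variance (ddof = 0) instead of standard deviation; `merge_stats` is a hypothetical helper, not part of the patch series:

```rust
/// Exact pairwise merge of (count, mean, variance) for two batches,
/// following Chan et al.'s parallel algorithm.
fn merge_stats(
    (n_a, mean_a, var_a): (f64, f64, f64),
    (n_b, mean_b, var_b): (f64, f64, f64),
) -> (f64, f64, f64) {
    let n = n_a + n_b;
    let delta = mean_b - mean_a;
    let mean = mean_a + delta * n_b / n;
    // Work with sums of squared deviations (M2 = var * n), which add cleanly.
    let m2 = var_a * n_a + var_b * n_b + delta.powi(2) * n_a * n_b / n;
    (n, mean, m2 / n)
}
```

Taking a square root at the end recovers the standard deviation; a ddof adjustment would divide `m2` by `n - ddof` instead of `n`.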
From 39b73797b4301f695c5b4d3914a5441acbaf848e Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 17:39:35 +0100
Subject: [PATCH 30/35] Add very basic usage example

---
 Cargo.toml        |  2 ++
 examples/basic.rs | 56 +++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 7768afe..3008f0f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,4 +8,6 @@ edition = "2018"
 
 [dev-dependencies]
 ndarray = "0.12.1"
+ndarray-rand = "0.9.0"
+rand = "*"
 derive_more = "0.13.0"
diff --git a/examples/basic.rs b/examples/basic.rs
index 348a5ed..edceb9b 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -1,10 +1,14 @@
 extern crate linfa;
 extern crate ndarray;
+extern crate ndarray_rand;
+extern crate rand;
 #[macro_use]
 extern crate derive_more;
 
 use linfa::{Blueprint, Fit, IncrementalFit, Transformer};
-use ndarray::{Array1, ArrayBase, Axis, Data, Ix1};
+use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, stack};
+use ndarray_rand::RandomExt;
+use rand::distributions::Uniform;
 use std::error::Error;
 
 /// Short-hand notations
@@ -36,7 +40,14 @@ pub struct StandardScaler {
 /// It keeps track of the number of samples seen so far, to allow for
 /// incremental computation of mean and standard deviation.
 pub struct OnlineOptimizer {
-    n_samples: u64,
+    pub n_samples: u64,
+}
+
+/// Initialize n_samples to 0.
+impl Default for OnlineOptimizer {
+    fn default() -> Self {
+        Self { n_samples: 0 }
+    }
 }
 
 impl<S> Fit<Config, Input<S>, Output> for OnlineOptimizer
@@ -142,6 +153,43 @@ where
     }
 }
 
-fn main() {
-    println!("Hello world!");
+fn generate_batch(n_samples: usize) -> (Array1<f64>, Array1<f64>) {
+    let distribution = Uniform::new(0., 10.);
+    let x = Array1::random(n_samples, distribution);
+    let y = Array1::random(n_samples, distribution);
+    (x, y)
+}
+
+fn check<S>(scaler: &StandardScaler, x: &ArrayBase<S, Ix1>) -> Result<(), ScalingError>
+where
+    S: Data<Elem = f64>
+{
+    let old_batch_mean = x.mean_axis(Axis(0)).into_scalar();
+    let new_batch_mean = scaler
+        .transform(&x)?
+        .mean_axis(Axis(0))
+        .into_scalar();
+    println!(
+        "The mean.\nBefore scaling: {:?}\nAfter scaling: {:?}\n",
+        old_batch_mean, new_batch_mean
+    );
+    Ok(())
+}
+
+fn main() -> Result<(), ScalingError> {
+    let n_samples = 20;
+    let (x, y) = generate_batch(n_samples);
+
+    let mut optimizer = OnlineOptimizer::default();
+    let standard_scaler = optimizer.fit(&x, &y, Config::default())?;
+
+    check(&standard_scaler, &x)?;
+
+    let (x2, y2) = generate_batch(n_samples);
+    let standard_scaler = optimizer.incremental_fit(&x2, &y2, standard_scaler)?;
+
+    let whole_x = stack(Axis(0), &[x.view(), x2.view()]).expect("Failed to stack arrays");
+    check(&standard_scaler, &whole_x)?;
+
+    Ok(())
 }

From 8616d44bad829d2803f3b60f36696e0d679d1d1a Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 17:40:32 +0100
Subject: [PATCH 31/35] Move into folder

---
 examples/{basic.rs => running_mean/main.rs} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename examples/{basic.rs => running_mean/main.rs} (100%)

diff --git a/examples/basic.rs b/examples/running_mean/main.rs
similarity index 100%
rename from examples/basic.rs
rename to examples/running_mean/main.rs
From 2a54ddd9e64cc6f36ba90be89744651a4a34fbc9 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 17:51:10 +0100
Subject: [PATCH 32/35] Restructure into a proper module

---
 examples/running_mean/main.rs                 | 147 +-----------------
 .../running_mean/standard_scaler/config.rs    |  25 +++
 examples/running_mean/standard_scaler/mod.rs  |  45 ++++++
 .../running_mean/standard_scaler/optimizer.rs |  89 +++++++++++
 4 files changed, 163 insertions(+), 143 deletions(-)
 create mode 100644 examples/running_mean/standard_scaler/config.rs
 create mode 100644 examples/running_mean/standard_scaler/mod.rs
 create mode 100644 examples/running_mean/standard_scaler/optimizer.rs

diff --git a/examples/running_mean/main.rs b/examples/running_mean/main.rs
index edceb9b..a07ca2a 100644
--- a/examples/running_mean/main.rs
+++ b/examples/running_mean/main.rs
@@ -5,153 +5,13 @@ extern crate rand;
 #[macro_use]
 extern crate derive_more;
 
-use linfa::{Blueprint, Fit, IncrementalFit, Transformer};
+use linfa::{Fit, IncrementalFit, Transformer};
 use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, stack};
 use ndarray_rand::RandomExt;
 use rand::distributions::Uniform;
-use std::error::Error;
+use crate::standard_scaler::{StandardScaler, ScalingError, Config, OnlineOptimizer};
 
-/// Short-hand notations
-type Input<S> = ArrayBase<S, Ix1>;
-type Output = Array1<f64>;
-
-/// Fast-and-dirty error struct
-#[derive(Debug, Eq, PartialEq, From, Display)]
-pub struct ScalingError {}
-
-impl Error for ScalingError {}
-
-/// Given an input, it rescales it to have zero mean and unit variance.
-///
-/// We use 64-bit floats for simplicity.
-pub struct StandardScaler {
-    // Delta degrees of freedom.
-    // With ddof = 1, you get the sample standard deviation
-    // With ddof = 0, you get the population standard deviation
-    pub ddof: f64,
-    pub mean: f64,
-    pub standard_deviation: f64,
-}
-
-/// It keeps track of the number of samples seen so far, to allow for
-/// incremental computation of mean and standard deviation.
-pub struct OnlineOptimizer {
-    pub n_samples: u64,
-}
-
-/// Initialize n_samples to 0.
-impl Default for OnlineOptimizer {
-    fn default() -> Self {
-        Self { n_samples: 0 }
-    }
-}
-
-impl<S> Fit<Config, Input<S>, Output> for OnlineOptimizer
-where
-    S: Data<Elem = f64>,
-{
-    type Error = ScalingError;
-
-    fn fit(
-        &mut self,
-        inputs: &Input<S>,
-        _targets: &Output,
-        blueprint: Config,
-    ) -> Result<StandardScaler, Self::Error> {
-        if inputs.len() == 0 {
-            return Err(ScalingError {});
-        }
-        // Compute relevant quantities
-        let mean = inputs.mean_axis(Axis(0)).into_scalar();
-        let standard_deviation = inputs.std_axis(Axis(0), blueprint.ddof).into_scalar();
-        // Initialize n_samples using the array length
-        self.n_samples = inputs.len() as u64;
-        // Return new, tuned scaler
-        Ok(StandardScaler {
-            ddof: blueprint.ddof,
-            mean,
-            standard_deviation,
-        })
-    }
-}
-
-impl<S> IncrementalFit<StandardScaler, Input<S>, Output> for OnlineOptimizer
-where
-    S: Data<Elem = f64>,
-{
-    type Error = ScalingError;
-
-    fn incremental_fit(
-        &mut self,
-        inputs: &Input<S>,
-        _targets: &Output,
-        transformer: StandardScaler,
-    ) -> Result<StandardScaler, Self::Error> {
-        if inputs.len() == 0 {
-            // Nothing to be done
-            return Ok(transformer);
-        }
-        // Compute relevant quantities for the new batch
-        let batch_n_samples = inputs.len();
-        let batch_mean = inputs.mean_axis(Axis(0)).into_scalar();
-        let batch_std = inputs.std_axis(Axis(0), transformer.ddof).into_scalar();
-
-        // Update
-        let mean_delta = batch_mean - transformer.mean;
-        let new_n_samples = self.n_samples + (batch_n_samples as u64);
-        let new_mean =
-            transformer.mean + mean_delta * (batch_n_samples as f64) / (new_n_samples as f64);
-        let new_std = transformer.standard_deviation
-            + batch_std
-            + mean_delta.powi(2) * (self.n_samples as f64) * (batch_n_samples as f64)
-                / (new_n_samples as f64);
-
-        // Update n_samples
-        self.n_samples = new_n_samples;
-
-        // Return tuned scaler
-        Ok(StandardScaler {
-            ddof: transformer.ddof,
-            mean: new_mean,
-            standard_deviation: new_std,
-        })
-    }
-}
-
-pub struct Config {
-    // Delta degrees of freedom.
-    // With ddof = 1, you get the sample standard deviation
-    // With ddof = 0, you get the population standard deviation
-    ddof: f64,
-}
-
-/// Defaults to computing the sample standard deviation.
-impl Default for Config {
-    fn default() -> Self {
-        Self { ddof: 1. }
-    }
-}
-
-impl<S> Blueprint<Input<S>, Output> for Config
-where
-    S: Data<Elem = f64>,
-{
-    type Transformer = StandardScaler;
-}
-
-impl<S> Transformer<Input<S>, Output> for StandardScaler
-where
-    S: Data<Elem = f64>,
-{
-    type Error = ScalingError;
-
-    fn transform(&self, inputs: &Input<S>) -> Result<Output, Self::Error>
-    where
-        S: Data<Elem = f64>,
-    {
-        Ok((inputs - self.mean) / self.standard_deviation)
-    }
-}
+mod standard_scaler;
 
 fn generate_batch(n_samples: usize) -> (Array1<f64>, Array1<f64>) {
     let distribution = Uniform::new(0., 10.);
@@ -173,6 +33,7 @@ where
     Ok(())
 }
 
+/// Run it with: cargo run --example running_mean
 fn main() -> Result<(), ScalingError> {
     let n_samples = 20;
     let (x, y) = generate_batch(n_samples);
diff --git a/examples/running_mean/standard_scaler/config.rs b/examples/running_mean/standard_scaler/config.rs
new file mode 100644
index 0000000..eac3d5f
--- /dev/null
+++ b/examples/running_mean/standard_scaler/config.rs
@@ -0,0 +1,25 @@
+use crate::standard_scaler::{Input, Output, StandardScaler};
+use linfa::Blueprint;
+use ndarray::Data;
+
+pub struct Config {
+    // Delta degrees of freedom.
+    // With ddof = 1, you get the sample standard deviation
+    // With ddof = 0, you get the population standard deviation
+    pub ddof: f64,
+}
+
+/// Defaults to computing the sample standard deviation.
+impl Default for Config {
+    fn default() -> Self {
+        Self { ddof: 1. }
+    }
+}
+
+impl<S> Blueprint<Input<S>, Output> for Config
+    where
+        S: Data<Elem = f64>,
+{
+    type Transformer = StandardScaler;
+}
+
diff --git a/examples/running_mean/standard_scaler/mod.rs b/examples/running_mean/standard_scaler/mod.rs
new file mode 100644
index 0000000..098ce9b
--- /dev/null
+++ b/examples/running_mean/standard_scaler/mod.rs
@@ -0,0 +1,45 @@
+use ndarray::{Array1, ArrayBase, Ix1, Data};
+use linfa::Transformer;
+use std::error::Error;
+
+/// Short-hand notations
+type Input<S> = ArrayBase<S, Ix1>;
+type Output = Array1<f64>;
+
+/// Given an input, it rescales it to have zero mean and unit variance.
+///
+/// We use 64-bit floats for simplicity.
+pub struct StandardScaler {
+    // Delta degrees of freedom.
+    // With ddof = 1, you get the sample standard deviation
+    // With ddof = 0, you get the population standard deviation
+    pub ddof: f64,
+    pub mean: f64,
+    pub standard_deviation: f64,
+}
+
+/// Fast-and-dirty error struct
+#[derive(Debug, Eq, PartialEq, From, Display)]
+pub struct ScalingError {}
+
+impl Error for ScalingError {}
+
+impl<S> Transformer<Input<S>, Output> for StandardScaler
+    where
+        S: Data<Elem = f64>,
+{
+    type Error = ScalingError;
+
+    fn transform(&self, inputs: &Input<S>) -> Result<Output, Self::Error>
+    where
+        S: Data<Elem = f64>,
+    {
+        Ok((inputs - self.mean) / self.standard_deviation)
+    }
+}
+
+mod config;
+mod optimizer;
+
+pub use config::Config;
+pub use optimizer::OnlineOptimizer;
\ No newline at end of file
diff --git a/examples/running_mean/standard_scaler/optimizer.rs b/examples/running_mean/standard_scaler/optimizer.rs
new file mode 100644
index 0000000..8f1d061
--- /dev/null
+++ b/examples/running_mean/standard_scaler/optimizer.rs
@@ -0,0 +1,89 @@
+use crate::standard_scaler::{Config, StandardScaler, Input, Output, ScalingError};
+use linfa::{IncrementalFit, Fit};
+use ndarray::{Data, Axis};
+
+/// It keeps track of the number of samples seen so far, to allow for
+/// incremental computation of mean and standard deviation.
+pub struct OnlineOptimizer {
+    pub n_samples: u64,
+}
+
+/// Initialize n_samples to 0.
+impl Default for OnlineOptimizer {
+    fn default() -> Self {
+        Self { n_samples: 0 }
+    }
+}
+
+impl<S> Fit<Config, Input<S>, Output> for OnlineOptimizer
+    where
+        S: Data<Elem = f64>,
+{
+    type Error = ScalingError;
+
+    fn fit(
+        &mut self,
+        inputs: &Input<S>,
+        _targets: &Output,
+        blueprint: Config,
+    ) -> Result<StandardScaler, Self::Error> {
+        if inputs.len() == 0 {
+            return Err(ScalingError {});
+        }
+        // Compute relevant quantities
+        let mean = inputs.mean_axis(Axis(0)).into_scalar();
+        let standard_deviation = inputs.std_axis(Axis(0), blueprint.ddof).into_scalar();
+        // Initialize n_samples using the array length
+        self.n_samples = inputs.len() as u64;
+        // Return new, tuned scaler
+        Ok(StandardScaler {
+            ddof: blueprint.ddof,
+            mean,
+            standard_deviation,
+        })
+    }
+}
+
+impl<S> IncrementalFit<StandardScaler, Input<S>, Output> for OnlineOptimizer
+    where
+        S: Data<Elem = f64>,
+{
+    type Error = ScalingError;
+
+    fn incremental_fit(
+        &mut self,
+        inputs: &Input<S>,
+        _targets: &Output,
+        transformer: StandardScaler,
+    ) -> Result<StandardScaler, Self::Error> {
+        if inputs.len() == 0 {
+            // Nothing to be done
+            return Ok(transformer);
+        }
+        // Compute relevant quantities for the new batch
+        let batch_n_samples = inputs.len();
+        let batch_mean = inputs.mean_axis(Axis(0)).into_scalar();
+        let batch_std = inputs.std_axis(Axis(0), transformer.ddof).into_scalar();
+
+        // Update
+        let mean_delta = batch_mean - transformer.mean;
+        let new_n_samples = self.n_samples + (batch_n_samples as u64);
+        let new_mean =
+            transformer.mean + mean_delta * (batch_n_samples as f64) / (new_n_samples as f64);
+        let new_std = transformer.standard_deviation
+            + batch_std
+            + mean_delta.powi(2) * (self.n_samples as f64) * (batch_n_samples as f64)
+            / (new_n_samples as f64);
+
+        // Update n_samples
+        self.n_samples = new_n_samples;
+
+        // Return tuned scaler
+        Ok(StandardScaler {
+            ddof: transformer.ddof,
+            mean: new_mean,
+            standard_deviation: new_std,
+        })
+    }
+}
+
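Before the restructuring in the next patch, it is worth seeing how these pieces compose from a caller's point of view. The sketch below is not part of the patch series: it assumes the `Fit`, `IncrementalFit` and `Transformer` traits exactly as used above, reuses the example's own `generate_batch` helper, and the `run` function name is made up for illustration.

    use linfa::{Fit, IncrementalFit, Transformer};

    fn run() -> Result<(), ScalingError> {
        // Fit a scaler from scratch on a first batch, starting from the Config blueprint
        let (x1, y1) = generate_batch(20);
        let mut optimizer = OnlineOptimizer::default();
        let scaler = optimizer.fit(&x1, &y1, Config::default())?;

        // Refine the fitted scaler on a second batch instead of refitting from scratch
        let (x2, y2) = generate_batch(20);
        let scaler = optimizer.incremental_fit(&x2, &y2, scaler)?;

        // Use the fitted parameters to rescale new data
        let _rescaled = scaler.transform(&x2)?;
        Ok(())
    }

Note the division of labour: `OnlineOptimizer` owns the mutable fitting state (`n_samples`), while `StandardScaler` stays an immutable set of parameters that `incremental_fit` consumes and returns.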
From 3d23c4b9803b2df67ae4cbf4f074d3a68d1a80b5 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 17:58:15 +0100
Subject: [PATCH 33/35] Restructure into a proper module

---
 examples/running_mean/main.rs                 | 11 ++---
 .../running_mean/standard_scaler/config.rs    |  5 +--
 .../running_mean/standard_scaler/error.rs     |  7 ++++
 examples/running_mean/standard_scaler/mod.rs  | 42 +++----------------
 .../running_mean/standard_scaler/optimizer.rs | 17 ++++----
 .../standard_scaler/transformer.rs            | 29 +++++++++++++
 6 files changed, 56 insertions(+), 55 deletions(-)
 create mode 100644 examples/running_mean/standard_scaler/error.rs
 create mode 100644 examples/running_mean/standard_scaler/transformer.rs

diff --git a/examples/running_mean/main.rs b/examples/running_mean/main.rs
index a07ca2a..fdfff34 100644
--- a/examples/running_mean/main.rs
+++ b/examples/running_mean/main.rs
@@ -5,11 +5,11 @@ extern crate rand;
 #[macro_use]
 extern crate derive_more;
 
+use crate::standard_scaler::{Config, OnlineOptimizer, ScalingError, StandardScaler};
 use linfa::{Fit, IncrementalFit, Transformer};
-use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, stack};
+use ndarray::{stack, Array1, ArrayBase, Axis, Data, Ix1};
 use ndarray_rand::RandomExt;
 use rand::distributions::Uniform;
-use crate::standard_scaler::{StandardScaler, ScalingError, Config, OnlineOptimizer};
 
 mod standard_scaler;
 
@@ -22,13 +22,10 @@ fn generate_batch(n_samples: usize) -> (Array1<f64>, Array1<f64>) {
 
 fn check<S>(scaler: &StandardScaler, x: &ArrayBase<S, Ix1>) -> Result<(), ScalingError>
 where
-    S: Data<Elem = f64>
+    S: Data<Elem = f64>,
 {
     let old_batch_mean = x.mean_axis(Axis(0)).into_scalar();
-    let new_batch_mean = scaler
-        .transform(&x)?
-        .mean_axis(Axis(0))
-        .into_scalar();
+    let new_batch_mean = scaler.transform(&x)?.mean_axis(Axis(0)).into_scalar();
     println!(
         "The mean.\nBefore scaling: {:?}\nAfter scaling: {:?}\n",
         old_batch_mean, new_batch_mean
diff --git a/examples/running_mean/standard_scaler/config.rs b/examples/running_mean/standard_scaler/config.rs
index eac3d5f..8cb33ba 100644
--- a/examples/running_mean/standard_scaler/config.rs
+++ b/examples/running_mean/standard_scaler/config.rs
@@ -17,9 +17,8 @@ impl Default for Config {
 }
 
 impl<S> Blueprint<Input<S>, Output> for Config
-    where
-        S: Data<Elem = f64>,
+where
+    S: Data<Elem = f64>,
 {
     type Transformer = StandardScaler;
 }
-
diff --git a/examples/running_mean/standard_scaler/error.rs b/examples/running_mean/standard_scaler/error.rs
new file mode 100644
index 0000000..cb828ae
--- /dev/null
+++ b/examples/running_mean/standard_scaler/error.rs
@@ -0,0 +1,7 @@
+use std::error::Error;
+
+/// Fast-and-dirty error struct
+#[derive(Debug, Eq, PartialEq, From, Display)]
+pub struct ScalingError {}
+
+impl Error for ScalingError {}
diff --git a/examples/running_mean/standard_scaler/mod.rs b/examples/running_mean/standard_scaler/mod.rs
index 098ce9b..da91ba9 100644
--- a/examples/running_mean/standard_scaler/mod.rs
+++ b/examples/running_mean/standard_scaler/mod.rs
@@ -1,45 +1,15 @@
-use ndarray::{Array1, ArrayBase, Ix1, Data};
-use linfa::Transformer;
-use std::error::Error;
+use ndarray::{Array1, ArrayBase, Ix1};
 
 /// Short-hand notations
 type Input<S> = ArrayBase<S, Ix1>;
 type Output = Array1<f64>;
 
-/// Given an input, it rescales it to have zero mean and unit variance.
-///
-/// We use 64-bit floats for simplicity.
-pub struct StandardScaler {
-    // Delta degrees of freedom.
-    // With ddof = 1, you get the sample standard deviation
-    // With ddof = 0, you get the population standard deviation
-    pub ddof: f64,
-    pub mean: f64,
-    pub standard_deviation: f64,
-}
-
-/// Fast-and-dirty error struct
-#[derive(Debug, Eq, PartialEq, From, Display)]
-pub struct ScalingError {}
-
-impl Error for ScalingError {}
-
-impl<S> Transformer<Input<S>, Output> for StandardScaler
-    where
-        S: Data<Elem = f64>,
-{
-    type Error = ScalingError;
-
-    fn transform(&self, inputs: &Input<S>) -> Result<Output, Self::Error>
-    where
-        S: Data<Elem = f64>,
-    {
-        Ok((inputs - self.mean) / self.standard_deviation)
-    }
-}
-
 mod config;
+mod error;
 mod optimizer;
+mod transformer;
 
 pub use config::Config;
+pub use error::ScalingError;
 pub use optimizer::OnlineOptimizer;
+pub use transformer::StandardScaler;
diff --git a/examples/running_mean/standard_scaler/optimizer.rs b/examples/running_mean/standard_scaler/optimizer.rs
index 8f1d061..b02a06a 100644
--- a/examples/running_mean/standard_scaler/optimizer.rs
+++ b/examples/running_mean/standard_scaler/optimizer.rs
@@ -1,6 +1,6 @@
-use crate::standard_scaler::{Config, StandardScaler, Input, Output, ScalingError};
-use linfa::{IncrementalFit, Fit};
-use ndarray::{Data, Axis};
+use crate::standard_scaler::{Config, Input, Output, ScalingError, StandardScaler};
+use linfa::{Fit, IncrementalFit};
+use ndarray::{Axis, Data};
 
 /// It keeps track of the number of samples seen so far, to allow for
 /// incremental computation of mean and standard deviation.
@@ -16,8 +16,8 @@ impl Default for OnlineOptimizer {
 }
 
 impl<S> Fit<Config, Input<S>, Output> for OnlineOptimizer
-    where
-        S: Data<Elem = f64>,
+where
+    S: Data<Elem = f64>,
 {
     type Error = ScalingError;
 
@@ -45,8 +45,8 @@ impl<S> Fit<Config, Input<S>, Output> for OnlineOptimizer
 }
 
 impl<S> IncrementalFit<StandardScaler, Input<S>, Output> for OnlineOptimizer
-    where
-        S: Data<Elem = f64>,
+where
+    S: Data<Elem = f64>,
 {
     type Error = ScalingError;
 
@@ -73,7 +73,7 @@ impl<S> IncrementalFit<StandardScaler, Input<S>, Output> for OnlineOptimizer
         let new_std = transformer.standard_deviation
             + batch_std
             + mean_delta.powi(2) * (self.n_samples as f64) * (batch_n_samples as f64)
-            / (new_n_samples as f64);
+                / (new_n_samples as f64);
 
         // Update n_samples
         self.n_samples = new_n_samples;
@@ -86,4 +86,3 @@ impl<S> IncrementalFit<StandardScaler, Input<S>, Output> for OnlineOptimizer
         })
     }
 }
-
diff --git a/examples/running_mean/standard_scaler/transformer.rs b/examples/running_mean/standard_scaler/transformer.rs
new file mode 100644
index 0000000..dcae233
--- /dev/null
+++ b/examples/running_mean/standard_scaler/transformer.rs
@@ -0,0 +1,29 @@
+use crate::standard_scaler::{Input, Output, ScalingError};
+use linfa::Transformer;
+use ndarray::Data;
+
+/// Given an input, it rescales it to have zero mean and unit variance.
+///
+/// We use 64-bit floats for simplicity.
+pub struct StandardScaler {
+    // Delta degrees of freedom.
+    // With ddof = 1, you get the sample standard deviation
+    // With ddof = 0, you get the population standard deviation
+    pub ddof: f64,
+    pub mean: f64,
+    pub standard_deviation: f64,
+}
+
+impl<S> Transformer<Input<S>, Output> for StandardScaler
+where
+    S: Data<Elem = f64>,
+{
+    type Error = ScalingError;
+
+    fn transform(&self, inputs: &Input<S>) -> Result<Output, Self::Error>
+    where
+        S: Data<Elem = f64>,
+    {
+        Ok((inputs - self.mean) / self.standard_deviation)
+    }
+}
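For orientation, the layout after this patch (reconstructed from the diffstat above) is one file per concern:

    examples/running_mean/
    ├── main.rs                      // generate_batch, check, main
    └── standard_scaler/
        ├── mod.rs                   // Input/Output aliases and re-exports
        ├── config.rs                // Config, the Blueprint
        ├── error.rs                 // ScalingError
        ├── optimizer.rs             // OnlineOptimizer: Fit + IncrementalFit
        └── transformer.rs           // StandardScaler: Transformer

Each trait implementation now lives next to the type it belongs to, and `mod.rs` only holds the shared type aliases and the public surface.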
From a29de613b85c6dbb39c3381a65a68bb16a290b8c Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 18:13:47 +0100
Subject: [PATCH 34/35] Fix std dev update

---
 examples/running_mean/main.rs                      | 6 ++++++
 examples/running_mean/standard_scaler/optimizer.rs | 9 ++++++---
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/examples/running_mean/main.rs b/examples/running_mean/main.rs
index fdfff34..1edc81e 100644
--- a/examples/running_mean/main.rs
+++ b/examples/running_mean/main.rs
@@ -26,10 +26,16 @@ where
 {
     let old_batch_mean = x.mean_axis(Axis(0)).into_scalar();
     let new_batch_mean = scaler.transform(&x)?.mean_axis(Axis(0)).into_scalar();
+    let old_batch_std = x.std_axis(Axis(0), 1.).into_scalar();
+    let new_batch_std = scaler.transform(&x)?.std_axis(Axis(0), 1.).into_scalar();
     println!(
         "The mean.\nBefore scaling: {:?}\nAfter scaling: {:?}\n",
         old_batch_mean, new_batch_mean
     );
+    println!(
+        "The std deviation.\nBefore scaling: {:?}\nAfter scaling: {:?}\n",
+        old_batch_std, new_batch_std
+    );
     Ok(())
 }
 
diff --git a/examples/running_mean/standard_scaler/optimizer.rs b/examples/running_mean/standard_scaler/optimizer.rs
index b02a06a..c8b19fb 100644
--- a/examples/running_mean/standard_scaler/optimizer.rs
+++ b/examples/running_mean/standard_scaler/optimizer.rs
@@ -70,10 +70,13 @@ where
         let new_n_samples = self.n_samples + (batch_n_samples as u64);
         let new_mean =
             transformer.mean + mean_delta * (batch_n_samples as f64) / (new_n_samples as f64);
-        let new_std = transformer.standard_deviation
-            + batch_std
+        let new_std = ((transformer.standard_deviation.powi(2)
+            * (self.n_samples as f64 - transformer.ddof)
+            + batch_std.powi(2) * (batch_n_samples as f64 - transformer.ddof)
             + mean_delta.powi(2) * (self.n_samples as f64) * (batch_n_samples as f64)
-                / (new_n_samples as f64);
+                / (new_n_samples as f64))
+            / (new_n_samples as f64 - transformer.ddof))
+            .sqrt();
 
         // Update n_samples
         self.n_samples = new_n_samples;
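The fix deserves a word of explanation. The previous update added standard deviations directly, which is dimensionally wrong: sums of squared deviations are additive across batches, standard deviations are not. Writing n1 for the samples seen so far (mean m1, standard deviation s1), n2 for the new batch (mean m2, standard deviation s2), delta = m2 - m1 and n = n1 + n2, the new code computes

    M = s1^2 * (n1 - ddof) + s2^2 * (n2 - ddof) + delta^2 * n1 * n2 / n
    new_std = sqrt(M / (n - ddof))

i.e. it recovers the two sums of squared deviations from the stored standard deviations, adds the between-batch term delta^2 * n1 * n2 / n, and renormalizes. This is the standard pairwise update for merging variances (the same formula that underpins Chan et al.'s parallel variance algorithm), so the incremental result matches what a single pass over the concatenated batches would produce.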
From eef5f6f9b0d77f3ce9f82b370af47018a39bc12e Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 19 May 2019 18:19:13 +0100
Subject: [PATCH 35/35] Clean up code for std dev update

---
 examples/running_mean/standard_scaler/optimizer.rs | 14 ++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/running_mean/standard_scaler/optimizer.rs b/examples/running_mean/standard_scaler/optimizer.rs
index c8b19fb..2227939 100644
--- a/examples/running_mean/standard_scaler/optimizer.rs
+++ b/examples/running_mean/standard_scaler/optimizer.rs
@@ -60,22 +60,24 @@ where
             // Nothing to be done
             return Ok(transformer);
         }
+
+        let ddof = transformer.ddof;
+
         // Compute relevant quantities for the new batch
         let batch_n_samples = inputs.len();
         let batch_mean = inputs.mean_axis(Axis(0)).into_scalar();
-        let batch_std = inputs.std_axis(Axis(0), transformer.ddof).into_scalar();
+        let batch_std = inputs.std_axis(Axis(0), ddof).into_scalar();
 
         // Update
         let mean_delta = batch_mean - transformer.mean;
         let new_n_samples = self.n_samples + (batch_n_samples as u64);
         let new_mean =
             transformer.mean + mean_delta * (batch_n_samples as f64) / (new_n_samples as f64);
-        let new_std = ((transformer.standard_deviation.powi(2)
-            * (self.n_samples as f64 - transformer.ddof)
-            + batch_std.powi(2) * (batch_n_samples as f64 - transformer.ddof)
+        let new_std = ((transformer.standard_deviation.powi(2) * (self.n_samples as f64 - ddof)
+            + batch_std.powi(2) * (batch_n_samples as f64 - ddof)
             + mean_delta.powi(2) * (self.n_samples as f64) * (batch_n_samples as f64)
                 / (new_n_samples as f64))
-            / (new_n_samples as f64 - transformer.ddof))
+            / (new_n_samples as f64 - ddof))
             .sqrt();
 
         // Update n_samples
@@ -83,7 +85,7 @@ where
 
         // Return tuned scaler
         Ok(StandardScaler {
-            ddof: transformer.ddof,
+            ddof,
             mean: new_mean,
             standard_deviation: new_std,
         })
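To close the loop, here is a quick standalone check (not part of the series itself) that the incremental fit now agrees with a single-pass computation over the concatenated data. It assumes the example's `generate_batch` helper and the types as they stand after this patch; the `verify` function name is made up for illustration.

    use linfa::{Fit, IncrementalFit};
    use ndarray::{stack, Axis};

    fn verify() -> Result<(), ScalingError> {
        let (x1, y1) = generate_batch(10);
        let (x2, y2) = generate_batch(30);

        // Two-step incremental fit
        let mut optimizer = OnlineOptimizer::default();
        let scaler = optimizer.fit(&x1, &y1, Config::default())?;
        let scaler = optimizer.incremental_fit(&x2, &y2, scaler)?;

        // Single pass over the concatenated batches (ddof = 1 is the Config default)
        let full = stack(Axis(0), &[x1.view(), x2.view()]).unwrap();
        let full_mean = full.mean_axis(Axis(0)).into_scalar();
        let full_std = full.std_axis(Axis(0), 1.).into_scalar();

        assert!((scaler.mean - full_mean).abs() < 1e-10);
        assert!((scaler.standard_deviation - full_std).abs() < 1e-10);
        Ok(())
    }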