From 308663b01fa06b40099e9065095117f631def628 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 12 Jul 2024 11:51:43 -0400 Subject: [PATCH 01/18] Updated some docs --- R/bart.R | 2 +- README.md | 4 ++-- src/random_effects.cpp | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/R/bart.R b/R/bart.R index c9c9e050..2ccebace 100644 --- a/R/bart.R +++ b/R/bart.R @@ -356,7 +356,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, if (verbose) { if (num_burnin > 0) { if (((i - num_gfr) %% 100 == 0) || ((i - num_gfr) == num_burnin)) { - cat("Sampling", i - num_gfr, "out of", num_gfr, "BART burn-in draws\n") + cat("Sampling", i - num_gfr, "out of", num_burnin, "BART burn-in draws\n") } } if (num_mcmc > 0) { diff --git a/README.md b/README.md index c25e5503..644e5081 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# StochasticTree +# StochTree [![C++ Tests](https://github.com/StochasticTree/stochtree/actions/workflows/cpp-test.yml/badge.svg)](https://github.com/StochasticTree/stochtree/actions/workflows/cpp-test.yml) [![Python Tests](https://github.com/StochasticTree/stochtree/actions/workflows/python-test.yml/badge.svg)](https://github.com/StochasticTree/stochtree/actions/workflows/python-test.yml) @@ -8,7 +8,7 @@ Software for building stochastic tree ensembles (i.e. BART, XBART) for supervise # Getting Started -`StochasticTree` is composed of a C++ "core" and R / Python interfaces to that core. +`stochtree` is composed of a C++ "core" and R / Python interfaces to that core. Details on installation and use are available below: * [Python](#python-package) diff --git a/src/random_effects.cpp b/src/random_effects.cpp index 19108d48..bc746e81 100644 --- a/src/random_effects.cpp +++ b/src/random_effects.cpp @@ -1,4 +1,4 @@ -/*! Copyright (c) 2024 StochasticTree authors */ +/*! 
Copyright (c) 2024 stochtree authors */ #include namespace StochTree { From d450b552f06244ff7d9db7d36bd3c88b273dc62e Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 12 Jul 2024 15:01:04 -0400 Subject: [PATCH 02/18] Updated with a C++ only version of the classical versions of BART/XBART (constant leaves, no random effects, no leaf variance sampling) --- CMakeLists.txt | 1 + debug/api_debug.cpp | 182 ++++++++++++++++++++++++++++++---- include/stochtree/bart.h | 76 ++++++++++++++ include/stochtree/container.h | 3 + src/bart.cpp | 117 ++++++++++++++++++++++ src/container.cpp | 35 +++++-- 6 files changed, 386 insertions(+), 28 deletions(-) create mode 100644 include/stochtree/bart.h create mode 100644 src/bart.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c9c0796..b219d00c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,7 @@ set(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/build) file( GLOB SOURCES + src/bart.cpp src/container.cpp src/cutpoint_candidates.cpp src/data.cpp diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index a8753f35..0187bf58 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -1,4 +1,5 @@ /*! 
Copyright (c) 2024 stochtree authors*/ +#include #include #include #include @@ -265,6 +266,27 @@ void OutcomeOffsetScale(ColumnVector& residual, double& outcome_offset, double& } } +void OutcomeOffsetScale(std::vector& residual, double& outcome_offset, double& outcome_scale) { + data_size_t n = residual.size(); + double outcome_val = 0.0; + double outcome_sum = 0.0; + double outcome_sum_squares = 0.0; + double var_y = 0.0; + for (data_size_t i = 0; i < n; i++){ + outcome_val = residual.at(i); + outcome_sum += outcome_val; + outcome_sum_squares += std::pow(outcome_val, 2.0); + } + var_y = outcome_sum_squares / static_cast(n) - std::pow(outcome_sum / static_cast(n), 2.0); + outcome_scale = std::sqrt(var_y); + outcome_offset = outcome_sum / static_cast(n); + double previous_residual; + for (data_size_t i = 0; i < n; i++){ + previous_residual = residual.at(i); + residual.at(i) = (previous_residual - outcome_offset) / outcome_scale; + } +} + void sampleGFR(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { @@ -301,7 +323,7 @@ void sampleMCMC(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& } } -void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_mcmc = 100, int random_seed = -1) { +void RunDebugDeconstructed(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_burnin = 0, int num_mcmc = 100, int random_seed = -1) { // Flag the data as row-major bool row_major = true; @@ -528,32 +550,150 @@ void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int std::vector pred_parsed = forest_samples_parsed.Predict(dataset); } -} // namespace StochTree +void RunDebugLoop(int dgp_num = 0, bool rfx_included = 
false, int num_gfr = 10, int num_burnin = 0, int num_mcmc = 100, int random_seed = -1) { + // Flag the data as row-major + bool row_major = true; -int main(int argc, char* argv[]) { - // Unpack command line arguments - int dgp_num = std::stoi(argv[1]); - if ((dgp_num != 0) && (dgp_num != 1)) { - StochTree::Log::Fatal("The first command line argument must be 0 or 1"); + // Random number generation + std::mt19937 gen; + if (random_seed == -1) { + std::random_device rd; + std::mt19937 gen(rd()); } - int rfx_int = std::stoi(argv[2]); - if ((rfx_int != 0) && (rfx_int != 1)) { - StochTree::Log::Fatal("The second command line argument must be 0 or 1"); + else { + std::mt19937 gen(random_seed); } - bool rfx_included = static_cast(rfx_int); - int num_gfr = std::stoi(argv[3]); - if (num_gfr < 0) { - StochTree::Log::Fatal("The third command line argument must be >= 0"); + + // Empty data containers and dimensions (filled in by calling a specific DGP simulation function below) + int n; + int x_cols; + int omega_cols; + int y_cols; + int num_rfx_groups; + int rfx_basis_cols; + std::vector covariates_raw; + std::vector basis_raw; + std::vector rfx_basis_raw; + std::vector residual_raw; + std::vector rfx_groups; + std::vector feature_types; + + // Generate the data + int output_dimension; + bool is_leaf_constant; + ForestLeafModel leaf_model_type; + if (dgp_num == 0) { + GenerateDGP1(covariates_raw, basis_raw, residual_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed); + output_dimension = 1; + is_leaf_constant = true; + leaf_model_type = ForestLeafModel::kConstant; + } + else if (dgp_num == 1) { + GenerateDGP2(covariates_raw, basis_raw, residual_raw, rfx_basis_raw, rfx_groups, feature_types, gen, n, x_cols, omega_cols, y_cols, rfx_basis_cols, num_rfx_groups, rfx_included, random_seed); + output_dimension = 1; + is_leaf_constant = true; + leaf_model_type = ForestLeafModel::kConstant; } - 
int num_mcmc = std::stoi(argv[4]); - if (num_mcmc < 0) { - StochTree::Log::Fatal("The fourth command line argument must be >= 0"); + else { + Log::Fatal("Invalid dgp_num"); } - int random_seed = std::stoi(argv[5]); - if (random_seed < -1) { - StochTree::Log::Fatal("The fifth command line argument must be >= -0"); + + // Center and scale the data + double outcome_offset; + double outcome_scale; + OutcomeOffsetScale(residual_raw, outcome_offset, outcome_scale); + + // Construct loop sampling objects (override is_leaf_constant if necessary) + int num_trees = 50; + output_dimension = 1; + is_leaf_constant = true; + BARTDispatcher bart_dispatcher{}; + BARTResult bart_result = bart_dispatcher.CreateOutputObject(num_trees, output_dimension, is_leaf_constant); + + // Add covariates to sampling loop + bart_dispatcher.AddDataset(covariates_raw.data(), n, x_cols, row_major, true); + + // Add outcome to sampling loop + bart_dispatcher.AddTrainOutcome(residual_raw.data(), n); + + // Forest sampling parameters + double alpha = 0.9; + double beta = 2; + int min_samples_leaf = 1; + int cutpoint_grid_size = 100; + double a_leaf = 3.; + double b_leaf = 0.5 / num_trees; + double nu = 3.; + double lamb = 0.5; + double leaf_variance_init = 1. / num_trees; + double global_variance_init = 1.0; + + // Set variable weights + double const_var_wt = static_cast(1. 
/ x_cols); + std::vector variable_weights(x_cols, const_var_wt); + + // Run the BART sampling loop + bart_dispatcher.RunSampler(bart_result, feature_types, variable_weights, num_trees, num_gfr, num_burnin, num_mcmc, + global_variance_init, leaf_variance_init, alpha, beta, nu, lamb, a_leaf, b_leaf, + min_samples_leaf, cutpoint_grid_size); +} + +void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_burnin = 0, int num_mcmc = 100, int random_seed = -1, bool run_bart_loop = true) { + if (run_bart_loop) { + RunDebugLoop(dgp_num, rfx_included, num_gfr, num_burnin, num_mcmc, random_seed); + } else { + RunDebugDeconstructed(dgp_num, rfx_included, num_gfr, num_burnin, num_mcmc, random_seed); + } +} + +} // namespace StochTree + +int main(int argc, char* argv[]) { + int dgp_num, num_gfr, num_burnin, num_mcmc, random_seed; + bool rfx_included, run_bart_loop; + if (argc > 1) { + if (argc < 8) StochTree::Log::Fatal("Must provide 7 command line arguments"); + // Unpack command line arguments + dgp_num = std::stoi(argv[1]); + if ((dgp_num != 0) && (dgp_num != 1)) { + StochTree::Log::Fatal("The first command line argument must be 0 or 1"); + } + int rfx_int = std::stoi(argv[2]); + if ((rfx_int != 0) && (rfx_int != 1)) { + StochTree::Log::Fatal("The second command line argument must be 0 or 1"); + } + rfx_included = static_cast(rfx_int); + num_gfr = std::stoi(argv[3]); + if (num_gfr < 0) { + StochTree::Log::Fatal("The third command line argument must be >= 0"); + } + num_burnin = std::stoi(argv[4]); + if (num_burnin < 0) { + StochTree::Log::Fatal("The fourth command line argument must be >= 0"); + } + num_mcmc = std::stoi(argv[5]); + if (num_mcmc < 0) { + StochTree::Log::Fatal("The fifth command line argument must be >= 0"); + } + random_seed = std::stoi(argv[6]); + if (random_seed < -1) { + StochTree::Log::Fatal("The sixth command line argument must be >= -1"); + } + int run_bart_loop_int = std::stoi(argv[7]); + if ((run_bart_loop_int != 0) && 
(run_bart_loop_int != 1)) { + StochTree::Log::Fatal("The seventh command line argument must be 0 or 1"); + } + run_bart_loop = static_cast(run_bart_loop_int); + } else { + dgp_num = 1; + rfx_included = false; + num_gfr = 10; + num_burnin = 0; + num_mcmc = 10; + random_seed = -1; + run_bart_loop = true; } // Run the debug program - StochTree::RunDebug(dgp_num, rfx_included, num_gfr, num_mcmc); + StochTree::RunDebug(dgp_num, rfx_included, num_gfr, num_burnin, num_mcmc, random_seed, run_bart_loop); } diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h new file mode 100644 index 00000000..26469cf1 --- /dev/null +++ b/include/stochtree/bart.h @@ -0,0 +1,76 @@ +/*! Copyright (c) 2024 stochtree authors. */ +#ifndef STOCHTREE_BART_H_ +#define STOCHTREE_BART_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace StochTree { + +class BARTResult { + public: + BARTResult(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) : + forests_samples_{num_trees, output_dimension, is_leaf_constant} {} + ~BARTResult() {} + ForestContainer& GetForests() {return forests_samples_;} + std::vector& GetTrainPreds() {return raw_preds_train_;} + std::vector& GetTestPreds() {return raw_preds_test_;} + std::vector& GetVarianceSamples() {return sigma_samples_;} + int NumGFRSamples() {return num_gfr_;} + int NumBurninSamples() {return num_burnin_;} + int NumMCMCSamples() {return num_mcmc_;} + int NumTrainObservations() {return num_train_;} + int NumTestObservations() {return num_test_;} + bool HasTestSet() {return has_test_set_;} + private: + ForestContainer forests_samples_; + std::vector raw_preds_train_; + std::vector raw_preds_test_; + std::vector sigma_samples_; + int num_gfr_{0}; + int num_burnin_{0}; + int num_mcmc_{0}; + int num_train_{0}; + int num_test_{0}; + bool has_test_set_{false}; +}; + +class BARTDispatcher { + public: + BARTDispatcher() {} + ~BARTDispatcher() {} + BARTResult 
CreateOutputObject(int num_trees, int output_dimension = 1, bool is_leaf_constant = true); + void RunSampler( + BARTResult& output, std::vector& feature_types, std::vector& variable_weights, + int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, double leaf_var_init, + double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, + int cutpoint_grid_size, int random_seed = -1 + ); + void AddDataset(double* covariates, data_size_t num_row, int num_col, bool is_row_major, bool train); + void AddTrainOutcome(double* outcome, data_size_t num_row); + private: + // Sampling details + int num_gfr_{0}; + int num_burnin_{0}; + int num_mcmc_{0}; + int num_train_{0}; + int num_test_{0}; + bool has_test_set_{false}; + + // Sampling data objects + ForestDataset train_dataset_; + ForestDataset test_dataset_; + ColumnVector train_outcome_; +}; + +} // namespace StochTree + +#endif // STOCHTREE_SAMPLING_DISPATCH_H_ diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 1af8dd2b..880f7564 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -33,6 +33,9 @@ class ForestContainer { std::vector Predict(ForestDataset& dataset); std::vector PredictRaw(ForestDataset& dataset); std::vector PredictRaw(ForestDataset& dataset, int forest_num); + void PredictInPlace(ForestDataset& dataset, std::vector& output); + void PredictRawInPlace(ForestDataset& dataset, std::vector& output); + void PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector& output); inline TreeEnsemble* GetEnsemble(int i) {return forests_[i].get();} inline int32_t NumSamples() {return num_samples_;} diff --git a/src/bart.cpp b/src/bart.cpp new file mode 100644 index 00000000..a1dc4ad8 --- /dev/null +++ b/src/bart.cpp @@ -0,0 +1,117 @@ +/*! 
Copyright (c) 2024 by stochtree authors */ +#include + +namespace StochTree { + +BARTResult BARTDispatcher::CreateOutputObject(int num_trees, int output_dimension, bool is_leaf_constant) { + return BARTResult(num_trees, output_dimension, is_leaf_constant); +} + +void BARTDispatcher::AddDataset(double* covariates, data_size_t num_row, int num_col, bool is_row_major, bool train) { + if (train) { + train_dataset_ = ForestDataset(); + train_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); + num_train_ = num_row; + } else { + test_dataset_ = ForestDataset(); + test_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); + has_test_set_ = true; + num_test_ = num_row; + } +} + +void BARTDispatcher::AddTrainOutcome(double* outcome, data_size_t num_row) { + train_outcome_ = ColumnVector(); + train_outcome_.LoadData(outcome, num_row); +} + +void BARTDispatcher::RunSampler( + BARTResult& output, std::vector& feature_types, std::vector& variable_weights, + int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, double leaf_var_init, + double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, + int cutpoint_grid_size, int random_seed +) { + // Unpack sampling details + num_gfr_ = num_gfr; + num_burnin_ = num_burnin; + num_mcmc_ = num_mcmc; + int num_samples = num_gfr + num_burnin + num_mcmc; + + // Random number generation + std::mt19937 rng; + if (random_seed == -1) { + std::random_device rd; + std::mt19937 rng(rd()); + } + else { + std::mt19937 rng(random_seed); + } + + // Obtain references to forest / parameter samples and predictions in BARTResult + ForestContainer& forest_samples = output.GetForests(); + std::vector& sigma2_samples = output.GetVarianceSamples(); + std::vector& train_preds = output.GetTrainPreds(); + std::vector& test_preds = output.GetTestPreds(); + + // Clear and prepare vectors to store results + sigma2_samples.clear(); + train_preds.clear(); + 
test_preds.clear(); + sigma2_samples.resize(num_samples); + train_preds.resize(num_samples*num_train_); + if (has_test_set_) test_preds.resize(num_samples*num_test_); + + // Initialize tracker and tree prior + ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_); + TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf); + + // Initialize variance model + GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); + + // Initialize leaf model and samplers + GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_var_init); + GFRForestSampler gfr_sampler = GFRForestSampler(cutpoint_grid_size); + MCMCForestSampler mcmc_sampler = MCMCForestSampler(); + + // Running variable for current sampled value of global outcome variance parameter + double global_var = global_var_init; + + // Run the XBART Gibbs sampler + int iter = 0; + if (num_gfr > 0) { + for (int i = 0; i < num_gfr; i++) { + // Sample the forests + gfr_sampler.SampleOneIter(tracker, forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, + rng, variable_weights, global_var, feature_types, false); + + // Sample the global outcome + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + + // Increment sample counter + iter++; + } + } + + // Run the MCMC sampler + if (num_burnin + num_mcmc > 0) { + for (int i = 0; i < num_burnin + num_mcmc; i++) { + // Sample the forests + mcmc_sampler.SampleOneIter(tracker, forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, + rng, variable_weights, global_var, true); + + // Sample the global outcome + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + + // Increment sample counter + iter++; + } + } + + // Predict forests + forest_samples.PredictInPlace(train_dataset_, 
train_preds); + if (has_test_set_) forest_samples.PredictInPlace(test_dataset_, test_preds); +} + +} // namespace StochTree diff --git a/src/container.cpp b/src/container.cpp index 747e6995..4f7b251b 100644 --- a/src/container.cpp +++ b/src/container.cpp @@ -68,36 +68,57 @@ std::vector ForestContainer::Predict(ForestDataset& dataset) { data_size_t n = dataset.NumObservations(); data_size_t total_output_size = n*num_samples_; std::vector output(total_output_size); + PredictInPlace(dataset, output); + return output; +} + +std::vector ForestContainer::PredictRaw(ForestDataset& dataset) { + data_size_t n = dataset.NumObservations(); + data_size_t total_output_size = n * output_dimension_ * num_samples_; + std::vector output(total_output_size); + PredictRawInPlace(dataset, output); + return output; +} + +std::vector ForestContainer::PredictRaw(ForestDataset& dataset, int forest_num) { + data_size_t n = dataset.NumObservations(); + data_size_t total_output_size = n * output_dimension_; + std::vector output(total_output_size); + PredictRawInPlace(dataset, forest_num, output); + return output; +} + +void ForestContainer::PredictInPlace(ForestDataset& dataset, std::vector& output) { + data_size_t n = dataset.NumObservations(); + data_size_t total_output_size = n*num_samples_; + CHECK_EQ(total_output_size, output.size()); data_size_t offset = 0; for (int i = 0; i < num_samples_; i++) { auto num_trees = forests_[i]->NumTrees(); forests_[i]->PredictInplace(dataset, output, 0, num_trees, offset); offset += n; } - return output; } -std::vector ForestContainer::PredictRaw(ForestDataset& dataset) { +void ForestContainer::PredictRawInPlace(ForestDataset& dataset, std::vector& output) { data_size_t n = dataset.NumObservations(); data_size_t total_output_size = n * output_dimension_ * num_samples_; - std::vector output(total_output_size); + CHECK_EQ(total_output_size, output.size()); data_size_t offset = 0; for (int i = 0; i < num_samples_; i++) { auto num_trees = 
forests_[i]->NumTrees(); forests_[i]->PredictRawInplace(dataset, output, 0, num_trees, offset); offset += n * output_dimension_; } - return output; } -std::vector ForestContainer::PredictRaw(ForestDataset& dataset, int forest_num) { +void ForestContainer::PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector& output) { data_size_t n = dataset.NumObservations(); data_size_t total_output_size = n * output_dimension_; - std::vector output(total_output_size); + CHECK_EQ(total_output_size, output.size()); data_size_t offset = 0; auto num_trees = forests_[forest_num]->NumTrees(); forests_[forest_num]->PredictRawInplace(dataset, output, 0, num_trees, offset); - return output; } /*! \brief Save to JSON */ From e607c79efd9686e37234f6addd093fbce3bd98e0 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 12 Jul 2024 15:20:38 -0400 Subject: [PATCH 03/18] Fixed include issue --- include/stochtree/bart.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index 26469cf1..9b87db7d 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -3,7 +3,6 @@ #define STOCHTREE_BART_H_ #include -#include #include #include #include From 861f4bc8c609034fb1626157b2e5cfe337581a6a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 12 Jul 2024 18:16:55 -0400 Subject: [PATCH 04/18] Wrapped C++ sampling loop in R function --- NAMESPACE | 1 + R/bart.R | 241 ++++++++++++++++++++++++++++++++ R/cpp11.R | 4 + man/bart_specialized.Rd | 118 ++++++++++++++++ man/preprocessTrainDataFrame.Rd | 2 - man/preprocessTrainMatrix.Rd | 2 - src/Makevars | 2 + src/R_bart.cpp | 59 ++++++++ src/cpp11.cpp | 8 ++ src/stochtree_types.h | 1 + 10 files changed, 434 insertions(+), 4 deletions(-) create mode 100644 man/bart_specialized.Rd create mode 100644 src/R_bart.cpp diff --git a/NAMESPACE b/NAMESPACE index ab87b7b9..4b3af259 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -5,6 +5,7 @@ S3method(getRandomEffectSamples,bcf) 
S3method(predict,bartmodel) S3method(predict,bcf) export(bart) +export(bart_specialized) export(bcf) export(computeForestKernels) export(computeForestLeafIndices) diff --git a/R/bart.R b/R/bart.R index 2ccebace..8e1ff8e6 100644 --- a/R/bart.R +++ b/R/bart.R @@ -622,6 +622,247 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL } } +#' Run the BART algorithm for supervised learning. +#' +#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. +#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be +#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, +#' categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata +#' that the column is ordered categorical). +#' @param y_train Outcome to be modeled by the ensemble. +#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. +#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with +#' that of `X_train`. +#' @param cutpoint_grid_size Maximum size of the "grid" of potential cutpoints to consider. Default: 100. +#' @param tau_init Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. +#' @param alpha Prior probability of splitting for a tree of depth 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. +#' @param beta Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. +#' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0. 
+#' @param min_samples_leaf Minimum allowable size of a leaf, in terms of training samples. Default: 5. +#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3. +#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021). +#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3. +#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. +#' @param q Quantile used to calibrated `lambda` as in Sparapani et al (2021). Default: 0.9. +#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here. +#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. +#' @param num_trees Number of trees in the ensemble. Default: 200. +#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. +#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. +#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. +#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`. +#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0. +#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default TRUE. Ignored if num_mcmc = 0. +#' @param verbose Whether or not to print progress during the sampling loops. 
Default: FALSE. +#' +#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). +#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' f_XW <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +#' ) +#' noise_sd <- 1 +#' y <- f_XW + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' bart_model <- bart_specialized(X_train = X_train, y_train = y_train, X_test = X_test) +#' # plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") +#' # abline(0,1,col="red",lty=3,lwd=3) +bart_specialized <- function( + X_train, y_train, X_test = NULL, cutpoint_grid_size = 100, + tau_init = NULL, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, + nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, + q = 0.9, sigma2_init = NULL, variable_weights = NULL, + num_trees = 200, num_gfr = 5, num_burnin = 0, num_mcmc = 100, + random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F +){ + # Variable weight preprocessing (and initialization if necessary) + if (is.null(variable_weights)) { + variable_weights = rep(1/ncol(X_train), ncol(X_train)) + } + if (any(variable_weights < 0)) { + stop("variable_weights cannot have any negative weights") + } + + # Preprocess covariates + if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { + stop("X_train must be a matrix or dataframe") + } + if (!is.null(X_test)){ + if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { + stop("X_test must be a matrix or 
dataframe") + } + } + if (ncol(X_train) != length(variable_weights)) { + stop("length(variable_weights) must equal ncol(X_train)") + } + train_cov_preprocess_list <- preprocessTrainData(X_train) + X_train_metadata <- train_cov_preprocess_list$metadata + X_train <- train_cov_preprocess_list$data + original_var_indices <- X_train_metadata$original_var_indices + feature_types <- X_train_metadata$feature_types + feature_types <- as.integer(feature_types) + if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata) + + # Update variable weights + variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x)) + variable_weights <- variable_weights[original_var_indices]*variable_weights_adj + + # Data consistency checks + if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { + stop("X_train and X_test must have the same number of columns") + } + if (nrow(X_train) != length(y_train)) { + stop("X_train and y_train must have the same number of observations") + } + + # Convert y_train to numeric vector if not already converted + if (!is.null(dim(y_train))) { + y_train <- as.matrix(y_train) + } + + # Determine whether a basis vector is provided + has_basis = F + + # Determine whether a test set is provided + has_test = !is.null(X_test) + + # Standardize outcome separately for test and train + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) + resid_train <- (y_train-y_bar_train)/y_std_train + + # Calibrate priors for sigma^2 and tau + reg_basis <- X_train + sigma2hat <- (sigma(lm(resid_train~reg_basis)))^2 + quantile_cutoff <- 0.9 + if (is.null(lambda)) { + lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu + } + if (is.null(sigma2_init)) sigma2_init <- sigma2hat + if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees) + if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees) + current_leaf_scale <- as.matrix(tau_init) + current_sigma2 <- sigma2_init + + # Determine leaf model type + 
leaf_model <- 0 + + # Unpack model type info + output_dimension = 1 + is_leaf_constant = T + leaf_regression = F + + # Container of variance parameter samples + num_samples <- num_gfr + num_burnin + num_mcmc + + # Run the BART sampler + bart_result_ptr <- run_bart_cpp( + as.numeric(X_train), y_train, feature_types, variable_weights, nrow(X_train), + ncol(X_train), num_trees, output_dimension, is_leaf_constant, alpha, beta, + min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lambda, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed + ) +# +# # Forest predictions +# y_hat_train <- forest_samples$predict(forest_dataset_train)*y_std_train + y_bar_train +# if (has_test) y_hat_test <- forest_samples$predict(forest_dataset_test)*y_std_train + y_bar_train +# +# # Random effects predictions +# if (has_rfx) { +# rfx_preds_train <- rfx_samples$predict(group_ids_train, rfx_basis_train)*y_std_train +# y_hat_train <- y_hat_train + rfx_preds_train +# } +# if ((has_rfx_test) && (has_test)) { +# rfx_preds_test <- rfx_samples$predict(group_ids_test, rfx_basis_test)*y_std_train +# y_hat_test <- y_hat_test + rfx_preds_test +# } + + # # Compute retention indices + # if (num_mcmc > 0) { + # keep_indices = mcmc_indices + # if (keep_gfr) keep_indices <- c(gfr_indices, keep_indices) + # if (keep_burnin) keep_indices <- c(burnin_indices, keep_indices) + # } else { + # if ((num_gfr > 0) && (num_burnin > 0)) { + # # Override keep_gfr = FALSE since there are no MCMC samples + # # Don't retain both GFR and burnin samples + # keep_indices = gfr_indices + # } else if ((num_gfr <= 0) && (num_burnin > 0)) { + # # Override keep_burnin = FALSE since there are no MCMC or GFR samples + # keep_indices = burnin_indices + # } else if ((num_gfr > 0) && (num_burnin <= 0)) { + # # Override keep_gfr = FALSE since there are no MCMC samples + # keep_indices = gfr_indices + # } else { + # stop("There are no samples to retain!") + # } + # } + # + # # Subset forest and RFX predictions + # 
y_hat_train <- y_hat_train[,keep_indices] + # if (has_test) { + # y_hat_test <- y_hat_test[,keep_indices] + # } + # + # # Global error variance + # if (sample_sigma) sigma2_samples <- global_var_samples[keep_indices]*(y_std_train^2) + + # Return results as a list + model_params <- list( + "sigma2_init" = sigma2_init, + "nu" = nu, + "lambda" = lambda, + "tau_init" = tau_init, + "a" = a_leaf, + "b" = b_leaf, + "outcome_mean" = y_bar_train, + "outcome_scale" = y_std_train, + "output_dimension" = output_dimension, + "is_leaf_constant" = is_leaf_constant, + "leaf_regression" = leaf_regression, + "requires_basis" = F, + "num_covariates" = ncol(X_train), + "num_basis" = 0, + "num_samples" = num_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, + "has_basis" = F, + "has_rfx" = F, + "has_rfx_basis" = F, + "num_rfx_basis" = 0, + "sample_sigma" = T, + "sample_tau" = F + ) + result <- list( + # "forests" = forest_samples, + "model_params" = model_params + # "y_hat_train" = y_hat_train, + # "train_set_metadata" = X_train_metadata, + # "keep_indices" = keep_indices + ) + # if (has_test) result[["y_hat_test"]] = y_hat_test + # if (sample_sigma) result[["sigma2_samples"]] = sigma2_samples + class(result) <- "simplifiedbart" + + return(result) +} + #' Extract raw sample values for each of the random effect parameter terms. #' #' @param object Object of type `bcf` containing draws of a Bayesian causal forest model and associated sampling outputs. 
diff --git a/R/cpp11.R b/R/cpp11.R index 1d08b6a1..862b51be 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -1,5 +1,9 @@ # Generated by cpp11: do not edit by hand +run_bart_cpp <- function(covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed) { + .Call(`_stochtree_run_bart_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed) +} + create_forest_dataset_cpp <- function() { .Call(`_stochtree_create_forest_dataset_cpp`) } diff --git a/man/bart_specialized.Rd b/man/bart_specialized.Rd new file mode 100644 index 00000000..d270d650 --- /dev/null +++ b/man/bart_specialized.Rd @@ -0,0 +1,118 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{bart_specialized} +\alias{bart_specialized} +\title{Run the BART algorithm for supervised learning.} +\usage{ +bart_specialized( + X_train, + y_train, + X_test = NULL, + cutpoint_grid_size = 100, + tau_init = NULL, + alpha = 0.95, + beta = 2, + min_samples_leaf = 5, + nu = 3, + lambda = NULL, + a_leaf = 3, + b_leaf = NULL, + q = 0.9, + sigma2_init = NULL, + variable_weights = NULL, + num_trees = 200, + num_gfr = 5, + num_burnin = 0, + num_mcmc = 100, + random_seed = -1, + keep_burnin = F, + keep_gfr = F, + verbose = F +) +} +\arguments{ +\item{X_train}{Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. +Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be +preprocessed based on the variable types (e.g. 
categorical columns stored as unordered factors will be one-hot encoded, +categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata +that the column is ordered categorical).} + +\item{y_train}{Outcome to be modeled by the ensemble.} + +\item{X_test}{(Optional) Test set of covariates used to define "out of sample" evaluation data. +May be provided either as a dataframe or a matrix, but the format of \code{X_test} must be consistent with +that of \code{X_train}.} + +\item{cutpoint_grid_size}{Maximum size of the "grid" of potential cutpoints to consider. Default: 100.} + +\item{tau_init}{Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here.} + +\item{alpha}{Prior probability of splitting for a tree of depth 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.} + +\item{beta}{Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.} + +\item{min_samples_leaf}{Minimum allowable size of a leaf, in terms of training samples. Default: 5.} + +\item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.} + +\item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).} + +\item{a_leaf}{Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Default: 3.} + +\item{b_leaf}{Scale parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Calibrated internally as \code{0.5/num_trees} if not set here.} + +\item{q}{Quantile used to calibrated \code{lambda} as in Sparapani et al (2021). Default: 0.9.} + +\item{sigma2_init}{Starting value of global variance parameter. 
Calibrated internally as in Sparapani et al (2021) if not set here.} + +\item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here.} + +\item{num_trees}{Number of trees in the ensemble. Default: 200.} + +\item{num_gfr}{Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.} + +\item{num_burnin}{Number of "burn-in" iterations of the MCMC sampler. Default: 0.} + +\item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.} + +\item{random_seed}{Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}.} + +\item{keep_burnin}{Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.} + +\item{keep_gfr}{Whether or not "grow-from-root" samples should be included in cached predictions. Default TRUE. Ignored if num_mcmc = 0.} + +\item{verbose}{Whether or not to print progress during the sampling loops. Default: FALSE.} + +\item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.} +} +\value{ +List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). +} +\description{ +Run the BART algorithm for supervised learning. 
+} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart_specialized(X_train = X_train, y_train = y_train, X_test = X_test) +# plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") +# abline(0,1,col="red",lty=3,lwd=3) +} diff --git a/man/preprocessTrainDataFrame.Rd b/man/preprocessTrainDataFrame.Rd index d4b5b23a..66cbe41c 100644 --- a/man/preprocessTrainDataFrame.Rd +++ b/man/preprocessTrainDataFrame.Rd @@ -11,8 +11,6 @@ preprocessTrainDataFrame(input_df) \arguments{ \item{input_df}{Dataframe of covariates. 
Users must pre-process any categorical variables as factors (ordered for ordered categorical).} - -\item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable} } \value{ List with preprocessed data and details on the number of each type diff --git a/man/preprocessTrainMatrix.Rd b/man/preprocessTrainMatrix.Rd index b8cfd268..b90f7afe 100644 --- a/man/preprocessTrainMatrix.Rd +++ b/man/preprocessTrainMatrix.Rd @@ -9,8 +9,6 @@ preprocessTrainMatrix(input_matrix) } \arguments{ \item{input_matrix}{Covariate matrix.} - -\item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable} } \value{ List with preprocessed (unmodified) data and details on the number of each type diff --git a/src/Makevars b/src/Makevars index 53848f54..0714677b 100644 --- a/src/Makevars +++ b/src/Makevars @@ -10,11 +10,13 @@ CXX_STD=CXX17 OBJECTS = \ forest.o \ kernel.o \ + R_bart.o \ R_data.o \ R_random_effects.o \ sampler.o \ serialization.o \ cpp11.o \ + bart.o \ container.o \ cutpoint_candidates.o \ data.o \ diff --git a/src/R_bart.cpp b/src/R_bart.cpp new file mode 100644 index 00000000..2df7f428 --- /dev/null +++ b/src/R_bart.cpp @@ -0,0 +1,59 @@ +#include +#include "stochtree_types.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +[[cpp11::register]] +cpp11::external_pointer run_bart_cpp( + cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, + cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, + int output_dimension, bool is_leaf_constant, double alpha, double beta, + int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, + double nu, double lamb, double leaf_variance_init, double global_variance_init, + int num_gfr, int num_burnin, int num_mcmc, int random_seed +) { + // Create smart pointer to newly allocated object + std::unique_ptr bart_result_ptr_ = 
std::make_unique(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast(feature_types[i]); + } + + // Create BART dispatcher and add data + StochTree::BARTDispatcher bart_dispatcher{}; + double* covariate_data_ptr = REAL(PROTECT(covariates)); + double* outcome_data_ptr = REAL(PROTECT(outcome)); + bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true); + bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows); + + // Run the BART sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size + ); + + // Unprotect pointers to R data + UNPROTECT(2); + + // Release management of the pointer to R session + return cpp11::external_pointer(bart_result_ptr_.release()); +} + diff --git a/src/cpp11.cpp b/src/cpp11.cpp index d8d0182b..30c8aa35 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -5,6 +5,13 @@ #include "cpp11/declarations.hpp" #include +// R_bart.cpp +cpp11::external_pointer run_bart_cpp(cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, double nu, double lamb, double leaf_variance_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int 
random_seed); +extern "C" SEXP _stochtree_run_bart_cpp(SEXP covariates, SEXP outcome, SEXP feature_types, SEXP variable_weights, SEXP num_rows, SEXP num_covariates, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP leaf_variance_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp(cpp11::as_cpp>(covariates), cpp11::as_cpp>(outcome), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_rows), cpp11::as_cpp>(num_covariates), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(leaf_variance_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed))); + END_CPP11 +} // R_data.cpp cpp11::external_pointer create_forest_dataset_cpp(); extern "C" SEXP _stochtree_create_forest_dataset_cpp() { @@ -869,6 +876,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, + {"_stochtree_run_bart_cpp", (DL_FUNC) &_stochtree_run_bart_cpp, 23}, {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 13}, {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 13}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) 
&_stochtree_sample_sigma2_one_iteration_cpp, 4}, diff --git a/src/stochtree_types.h b/src/stochtree_types.h index 096dc58d..584f19b4 100644 --- a/src/stochtree_types.h +++ b/src/stochtree_types.h @@ -1,3 +1,4 @@ +#include #include #include #include From 4249f38cbecf39aa9e34e19c46812135efd185a6 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 15 Jul 2024 00:02:56 -0400 Subject: [PATCH 05/18] Updated BART classes to use unique pointer to ForestContainer (for easier linking to R and Python) and to include implementation in header file (to make possible to template the entire class) --- include/stochtree/bart.h | 122 ++++++++++++++++++++++++++++++++++++--- src/Makevars | 1 - src/bart.cpp | 117 ------------------------------------- 3 files changed, 114 insertions(+), 126 deletions(-) delete mode 100644 src/bart.cpp diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index 9b87db7d..242cbc65 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -12,14 +12,17 @@ #include #include +#include + namespace StochTree { class BARTResult { public: - BARTResult(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) : - forests_samples_{num_trees, output_dimension, is_leaf_constant} {} + BARTResult(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) { + forest_samples_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); + } ~BARTResult() {} - ForestContainer& GetForests() {return forests_samples_;} + ForestContainer* GetForests() {return forest_samples_.get();} std::vector& GetTrainPreds() {return raw_preds_train_;} std::vector& GetTestPreds() {return raw_preds_test_;} std::vector& GetVarianceSamples() {return sigma_samples_;} @@ -30,7 +33,7 @@ class BARTResult { int NumTestObservations() {return num_test_;} bool HasTestSet() {return has_test_set_;} private: - ForestContainer forests_samples_; + std::unique_ptr forest_samples_; std::vector raw_preds_train_; std::vector raw_preds_test_; std::vector 
sigma_samples_; @@ -46,15 +49,118 @@ class BARTDispatcher { public: BARTDispatcher() {} ~BARTDispatcher() {} - BARTResult CreateOutputObject(int num_trees, int output_dimension = 1, bool is_leaf_constant = true); + + BARTResult CreateOutputObject(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) { + return BARTResult(num_trees, output_dimension, is_leaf_constant); + } + + void AddDataset(double* covariates, data_size_t num_row, int num_col, bool is_row_major, bool train) { + if (train) { + train_dataset_ = ForestDataset(); + train_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); + num_train_ = num_row; + } else { + test_dataset_ = ForestDataset(); + test_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); + has_test_set_ = true; + num_test_ = num_row; + } + } + + void AddTrainOutcome(double* outcome, data_size_t num_row) { + train_outcome_ = ColumnVector(); + train_outcome_.LoadData(outcome, num_row); + } + void RunSampler( BARTResult& output, std::vector& feature_types, std::vector& variable_weights, int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, double leaf_var_init, double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, int cutpoint_grid_size, int random_seed = -1 - ); - void AddDataset(double* covariates, data_size_t num_row, int num_col, bool is_row_major, bool train); - void AddTrainOutcome(double* outcome, data_size_t num_row); + ) { + // Unpack sampling details + num_gfr_ = num_gfr; + num_burnin_ = num_burnin; + num_mcmc_ = num_mcmc; + int num_samples = num_gfr + num_burnin + num_mcmc; + + // Random number generation + std::mt19937 rng; + if (random_seed == -1) { + std::random_device rd; + std::mt19937 rng(rd()); + } + else { + std::mt19937 rng(random_seed); + } + + // Obtain references to forest / parameter samples and predictions in BARTResult + ForestContainer* forest_samples = output.GetForests(); + std::vector& 
sigma2_samples = output.GetVarianceSamples(); + std::vector& train_preds = output.GetTrainPreds(); + std::vector& test_preds = output.GetTestPreds(); + + // Clear and prepare vectors to store results + sigma2_samples.clear(); + train_preds.clear(); + test_preds.clear(); + sigma2_samples.resize(num_samples); + train_preds.resize(num_samples*num_train_); + if (has_test_set_) test_preds.resize(num_samples*num_test_); + + // Initialize tracker and tree prior + ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_); + TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf); + + // Initialize variance model + GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); + + // Initialize leaf model and samplers + GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_var_init); + GFRForestSampler gfr_sampler = GFRForestSampler(cutpoint_grid_size); + MCMCForestSampler mcmc_sampler = MCMCForestSampler(); + + // Running variable for current sampled value of global outcome variance parameter + double global_var = global_var_init; + + // Run the XBART Gibbs sampler + int iter = 0; + if (num_gfr > 0) { + for (int i = 0; i < num_gfr; i++) { + // Sample the forests + gfr_sampler.SampleOneIter(tracker, *forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, + rng, variable_weights, global_var, feature_types, false); + + // Sample the global outcome + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + + // Increment sample counter + iter++; + } + } + + // Run the MCMC sampler + if (num_burnin + num_mcmc > 0) { + for (int i = 0; i < num_burnin + num_mcmc; i++) { + // Sample the forests + mcmc_sampler.SampleOneIter(tracker, *forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, + rng, variable_weights, global_var, true); + + // Sample the global outcome + 
global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + + // Increment sample counter + iter++; + } + } + + // Predict forests + forest_samples->PredictInPlace(train_dataset_, train_preds); + if (has_test_set_) forest_samples->PredictInPlace(test_dataset_, test_preds); + } + private: // Sampling details int num_gfr_{0}; diff --git a/src/Makevars b/src/Makevars index 0714677b..4cf92c4a 100644 --- a/src/Makevars +++ b/src/Makevars @@ -16,7 +16,6 @@ OBJECTS = \ sampler.o \ serialization.o \ cpp11.o \ - bart.o \ container.o \ cutpoint_candidates.o \ data.o \ diff --git a/src/bart.cpp b/src/bart.cpp deleted file mode 100644 index a1dc4ad8..00000000 --- a/src/bart.cpp +++ /dev/null @@ -1,117 +0,0 @@ -/*! Copyright (c) 2024 by stochtree authors */ -#include - -namespace StochTree { - -BARTResult BARTDispatcher::CreateOutputObject(int num_trees, int output_dimension, bool is_leaf_constant) { - return BARTResult(num_trees, output_dimension, is_leaf_constant); -} - -void BARTDispatcher::AddDataset(double* covariates, data_size_t num_row, int num_col, bool is_row_major, bool train) { - if (train) { - train_dataset_ = ForestDataset(); - train_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); - num_train_ = num_row; - } else { - test_dataset_ = ForestDataset(); - test_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); - has_test_set_ = true; - num_test_ = num_row; - } -} - -void BARTDispatcher::AddTrainOutcome(double* outcome, data_size_t num_row) { - train_outcome_ = ColumnVector(); - train_outcome_.LoadData(outcome, num_row); -} - -void BARTDispatcher::RunSampler( - BARTResult& output, std::vector& feature_types, std::vector& variable_weights, - int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, double leaf_var_init, - double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, - int 
cutpoint_grid_size, int random_seed -) { - // Unpack sampling details - num_gfr_ = num_gfr; - num_burnin_ = num_burnin; - num_mcmc_ = num_mcmc; - int num_samples = num_gfr + num_burnin + num_mcmc; - - // Random number generation - std::mt19937 rng; - if (random_seed == -1) { - std::random_device rd; - std::mt19937 rng(rd()); - } - else { - std::mt19937 rng(random_seed); - } - - // Obtain references to forest / parameter samples and predictions in BARTResult - ForestContainer& forest_samples = output.GetForests(); - std::vector& sigma2_samples = output.GetVarianceSamples(); - std::vector& train_preds = output.GetTrainPreds(); - std::vector& test_preds = output.GetTestPreds(); - - // Clear and prepare vectors to store results - sigma2_samples.clear(); - train_preds.clear(); - test_preds.clear(); - sigma2_samples.resize(num_samples); - train_preds.resize(num_samples*num_train_); - if (has_test_set_) test_preds.resize(num_samples*num_test_); - - // Initialize tracker and tree prior - ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_); - TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf); - - // Initialize variance model - GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); - - // Initialize leaf model and samplers - GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_var_init); - GFRForestSampler gfr_sampler = GFRForestSampler(cutpoint_grid_size); - MCMCForestSampler mcmc_sampler = MCMCForestSampler(); - - // Running variable for current sampled value of global outcome variance parameter - double global_var = global_var_init; - - // Run the XBART Gibbs sampler - int iter = 0; - if (num_gfr > 0) { - for (int i = 0; i < num_gfr; i++) { - // Sample the forests - gfr_sampler.SampleOneIter(tracker, forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, - rng, variable_weights, global_var, feature_types, false); - - // Sample the global 
outcome - global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); - sigma2_samples.at(iter) = global_var; - - // Increment sample counter - iter++; - } - } - - // Run the MCMC sampler - if (num_burnin + num_mcmc > 0) { - for (int i = 0; i < num_burnin + num_mcmc; i++) { - // Sample the forests - mcmc_sampler.SampleOneIter(tracker, forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, - rng, variable_weights, global_var, true); - - // Sample the global outcome - global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); - sigma2_samples.at(iter) = global_var; - - // Increment sample counter - iter++; - } - } - - // Predict forests - forest_samples.PredictInPlace(train_dataset_, train_preds); - if (has_test_set_) forest_samples.PredictInPlace(test_dataset_, test_preds); -} - -} // namespace StochTree From ff327109f2f9181a49ed8440c659b12222f67d42 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 15 Jul 2024 00:23:41 -0400 Subject: [PATCH 06/18] Updated R wrapper around templated BARTDispatcher class --- R/bart.R | 2 +- R/cpp11.R | 4 ++-- debug/api_debug.cpp | 2 +- include/stochtree/bart.h | 24 +++++++++++++++++--- src/R_bart.cpp | 49 ++++++++++++++++++++++++++++++---------- src/cpp11.cpp | 8 +++---- 6 files changed, 66 insertions(+), 23 deletions(-) diff --git a/R/bart.R b/R/bart.R index 8e1ff8e6..5c4f3ef4 100644 --- a/R/bart.R +++ b/R/bart.R @@ -775,7 +775,7 @@ bart_specialized <- function( as.numeric(X_train), y_train, feature_types, variable_weights, nrow(X_train), ncol(X_train), num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lambda, - tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, 0 ) # # # Forest predictions diff --git a/R/cpp11.R b/R/cpp11.R index 862b51be..f0add490 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ 
-1,7 +1,7 @@ # Generated by cpp11: do not edit by hand -run_bart_cpp <- function(covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed) { - .Call(`_stochtree_run_bart_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed) +run_bart_cpp <- function(covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int) { + .Call(`_stochtree_run_bart_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int) } create_forest_dataset_cpp <- function() { diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index 0187bf58..98340b40 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -607,7 +607,7 @@ void RunDebugLoop(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_trees = 50; output_dimension = 1; is_leaf_constant = true; - BARTDispatcher bart_dispatcher{}; + BARTDispatcher bart_dispatcher{}; BARTResult bart_result = bart_dispatcher.CreateOutputObject(num_trees, output_dimension, is_leaf_constant); // Add covariates to sampling loop diff --git a/include/stochtree/bart.h 
b/include/stochtree/bart.h index 242cbc65..b4fc60ee 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -45,6 +45,7 @@ class BARTResult { bool has_test_set_{false}; }; +template class BARTDispatcher { public: BARTDispatcher() {} @@ -67,6 +68,21 @@ class BARTDispatcher { } } + void AddDataset(double* covariates, double* basis, data_size_t num_row, int num_covariates, int num_basis, bool is_row_major, bool train) { + if (train) { + train_dataset_ = ForestDataset(); + train_dataset_.AddCovariates(covariates, num_row, num_covariates, is_row_major); + train_dataset_.AddBasis(basis, num_row, num_basis, is_row_major); + num_train_ = num_row; + } else { + test_dataset_ = ForestDataset(); + test_dataset_.AddCovariates(covariates, num_row, num_covariates, is_row_major); + test_dataset_.AddBasis(basis, num_row, num_basis, is_row_major); + has_test_set_ = true; + num_test_ = num_row; + } + } + void AddTrainOutcome(double* outcome, data_size_t num_row) { train_outcome_ = ColumnVector(); train_outcome_.LoadData(outcome, num_row); @@ -116,9 +132,11 @@ class BARTDispatcher { GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); // Initialize leaf model and samplers - GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_var_init); - GFRForestSampler gfr_sampler = GFRForestSampler(cutpoint_grid_size); - MCMCForestSampler mcmc_sampler = MCMCForestSampler(); + // TODO: add template specialization for GaussianMultivariateRegressionLeafModel which takes Eigen::MatrixXd& + // as initialization parameter instead of double + ModelType leaf_model = ModelType(leaf_var_init); + GFRForestSampler gfr_sampler = GFRForestSampler(cutpoint_grid_size); + MCMCForestSampler mcmc_sampler = MCMCForestSampler(); // Running variable for current sampled value of global outcome variance parameter double global_var = global_var_init; diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 2df7f428..046c8927 100644 --- a/src/R_bart.cpp +++ 
b/src/R_bart.cpp @@ -19,7 +19,7 @@ cpp11::external_pointer run_bart_cpp( int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, double nu, double lamb, double leaf_variance_init, double global_variance_init, - int num_gfr, int num_burnin, int num_mcmc, int random_seed + int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int ) { // Create smart pointer to newly allocated object std::unique_ptr bart_result_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); @@ -37,18 +37,44 @@ cpp11::external_pointer run_bart_cpp( } // Create BART dispatcher and add data - StochTree::BARTDispatcher bart_dispatcher{}; double* covariate_data_ptr = REAL(PROTECT(covariates)); double* outcome_data_ptr = REAL(PROTECT(outcome)); - bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true); - bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows); - - // Run the BART sampling loop - bart_dispatcher.RunSampler( - *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, - num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, - alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size - ); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true); + bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; 
+ bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true); + bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size + ); + } + // // TODO: Figure out dispatch here + // else { + // // Create the dispatcher and load the data + // StochTree::BARTDispatcher bart_dispatcher{}; + // bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true); + // bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows); + // // Run the sampling loop + // bart_dispatcher.RunSampler( + // *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + // num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, + // alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size + // ); + // } // Unprotect pointers to R data UNPROTECT(2); @@ -56,4 +82,3 @@ cpp11::external_pointer run_bart_cpp( // Release management of the pointer to R session return cpp11::external_pointer(bart_result_ptr_.release()); } - diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 30c8aa35..b0f61819 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -6,10 +6,10 @@ #include // R_bart.cpp -cpp11::external_pointer run_bart_cpp(cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, double nu, double lamb, double leaf_variance_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed); -extern "C" SEXP _stochtree_run_bart_cpp(SEXP covariates, 
SEXP outcome, SEXP feature_types, SEXP variable_weights, SEXP num_rows, SEXP num_covariates, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP leaf_variance_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed) { +cpp11::external_pointer run_bart_cpp(cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, double nu, double lamb, double leaf_variance_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int); +extern "C" SEXP _stochtree_run_bart_cpp(SEXP covariates, SEXP outcome, SEXP feature_types, SEXP variable_weights, SEXP num_rows, SEXP num_covariates, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP leaf_variance_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int) { BEGIN_CPP11 - return cpp11::as_sexp(run_bart_cpp(cpp11::as_cpp>(covariates), cpp11::as_cpp>(outcome), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_rows), cpp11::as_cpp>(num_covariates), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(leaf_variance_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), 
cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed))); + return cpp11::as_sexp(run_bart_cpp(cpp11::as_cpp>(covariates), cpp11::as_cpp>(outcome), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_rows), cpp11::as_cpp>(num_covariates), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(leaf_variance_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int))); END_CPP11 } // R_data.cpp @@ -876,7 +876,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, - {"_stochtree_run_bart_cpp", (DL_FUNC) &_stochtree_run_bart_cpp, 23}, + {"_stochtree_run_bart_cpp", (DL_FUNC) &_stochtree_run_bart_cpp, 24}, {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 13}, {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 13}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 4}, From 5e982ca05de2626ff2df3cd252b54e9279735a25 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 15 Jul 2024 15:39:24 -0400 Subject: [PATCH 07/18] Broadening scope of the C++ only sampling loop --- R/bart.R | 4 +- R/cpp11.R | 4 +- debug/api_debug.cpp | 7 ++- include/stochtree/bart.h | 91 +++++++++++++++++++++++------- include/stochtree/leaf_model.h | 12 ++++ 
include/stochtree/random_effects.h | 49 ++++++++++------ src/R_bart.cpp | 28 ++++++--- src/cpp11.cpp | 6 +- src/random_effects.cpp | 24 ++++++++ 9 files changed, 172 insertions(+), 53 deletions(-) diff --git a/R/bart.R b/R/bart.R index 5c4f3ef4..3d78dbb1 100644 --- a/R/bart.R +++ b/R/bart.R @@ -755,8 +755,8 @@ bart_specialized <- function( } if (is.null(sigma2_init)) sigma2_init <- sigma2hat if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees) - if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees) - current_leaf_scale <- as.matrix(tau_init) + if (is.null(tau_init)) tau_init <- as.matrix(var(resid_train)/(num_trees)) + current_leaf_scale <- tau_init current_sigma2 <- sigma2_init # Determine leaf model type diff --git a/R/cpp11.R b/R/cpp11.R index f0add490..dba48240 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -1,7 +1,7 @@ # Generated by cpp11: do not edit by hand -run_bart_cpp <- function(covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int) { - .Call(`_stochtree_run_bart_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int) +run_bart_cpp <- function(covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int) { + .Call(`_stochtree_run_bart_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, 
num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int) } create_forest_dataset_cpp <- function() { diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index 98340b40..717640ef 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -625,7 +625,8 @@ void RunDebugLoop(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, double b_leaf = 0.5 / num_trees; double nu = 3.; double lamb = 0.5; - double leaf_variance_init = 1. / num_trees; + Eigen::MatrixXd leaf_cov_init(1,1); + leaf_cov_init(0,0) = 1. / num_trees; double global_variance_init = 1.0; // Set variable weights @@ -634,8 +635,8 @@ void RunDebugLoop(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, // Run the BART sampling loop bart_dispatcher.RunSampler(bart_result, feature_types, variable_weights, num_trees, num_gfr, num_burnin, num_mcmc, - global_variance_init, leaf_variance_init, alpha, beta, nu, lamb, a_leaf, b_leaf, - min_samples_leaf, cutpoint_grid_size); + global_variance_init, leaf_cov_init, alpha, beta, nu, lamb, a_leaf, b_leaf, + min_samples_leaf, cutpoint_grid_size, true, false, -1); } void RunDebug(int dgp_num = 0, bool rfx_included = false, int num_gfr = 10, int num_burnin = 0, int num_mcmc = 100, int random_seed = -1, bool run_bart_loop = true) { diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index b4fc60ee..a19d4ea0 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -23,26 +23,41 @@ class BARTResult { } ~BARTResult() {} ForestContainer* GetForests() {return forest_samples_.get();} - std::vector& GetTrainPreds() {return raw_preds_train_;} - std::vector& GetTestPreds() {return raw_preds_test_;} - std::vector& GetVarianceSamples() {return sigma_samples_;} + ForestContainer* ReleaseForests() {return forest_samples_.release();} + RandomEffectsContainer* 
GetRFXContainer() {return rfx_container_.get();} + RandomEffectsContainer* ReleaseRFXContainer() {return rfx_container_.release();} + LabelMapper* GetRFXLabelMapper() {return rfx_label_mapper_.get();} + LabelMapper* ReleaseRFXLabelMapper() {return rfx_label_mapper_.release();} + std::vector& GetTrainPreds() {return outcome_preds_train_;} + std::vector& GetTestPreds() {return outcome_preds_test_;} + std::vector& GetGlobalVarianceSamples() {return sigma_samples_;} + std::vector& GetLeafVarianceSamples() {return tau_samples_;} int NumGFRSamples() {return num_gfr_;} int NumBurninSamples() {return num_burnin_;} int NumMCMCSamples() {return num_mcmc_;} int NumTrainObservations() {return num_train_;} int NumTestObservations() {return num_test_;} + bool IsGlobalVarRandom() {return is_global_var_random_;} + bool IsLeafVarRandom() {return is_leaf_var_random_;} bool HasTestSet() {return has_test_set_;} + bool HasRFX() {return has_rfx_;} private: std::unique_ptr forest_samples_; - std::vector raw_preds_train_; - std::vector raw_preds_test_; + std::unique_ptr rfx_container_; + std::unique_ptr rfx_label_mapper_; + std::vector outcome_preds_train_; + std::vector outcome_preds_test_; std::vector sigma_samples_; + std::vector tau_samples_; int num_gfr_{0}; int num_burnin_{0}; int num_mcmc_{0}; int num_train_{0}; int num_test_{0}; + bool is_global_var_random_{true}; + bool is_leaf_var_random_{false}; bool has_test_set_{false}; + bool has_rfx_{false}; }; template @@ -83,6 +98,23 @@ class BARTDispatcher { } } + void AddRFXTerm(double* rfx_basis, std::vector& rfx_group_indices, data_size_t num_row, int num_groups, int num_basis, bool is_row_major, bool train) { + if (train) { + rfx_train_dataset_ = RandomEffectsDataset(); + rfx_train_dataset_.AddBasis(rfx_basis, num_row, num_basis, is_row_major); + rfx_train_dataset_.AddGroupLabels(rfx_group_indices); + rfx_tracker_.Reset(rfx_group_indices); + rfx_model_.Reset(num_basis, num_groups); + num_rfx_groups_ = num_groups; + num_rfx_basis_ = 
num_basis; + has_rfx_ = true; + } else { + rfx_test_dataset_ = RandomEffectsDataset(); + rfx_test_dataset_.AddBasis(rfx_basis, num_row, num_basis, is_row_major); + rfx_test_dataset_.AddGroupLabels(rfx_group_indices); + } + } + void AddTrainOutcome(double* outcome, data_size_t num_row) { train_outcome_ = ColumnVector(); train_outcome_.LoadData(outcome, num_row); @@ -90,9 +122,9 @@ class BARTDispatcher { void RunSampler( BARTResult& output, std::vector& feature_types, std::vector& variable_weights, - int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, double leaf_var_init, - double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, - int cutpoint_grid_size, int random_seed = -1 + int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, Eigen::MatrixXd& leaf_cov_init, + double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, int cutpoint_grid_size, + bool sample_global_var, bool sample_leaf_var, int random_seed = -1 ) { // Unpack sampling details num_gfr_ = num_gfr; @@ -112,7 +144,8 @@ class BARTDispatcher { // Obtain references to forest / parameter samples and predictions in BARTResult ForestContainer* forest_samples = output.GetForests(); - std::vector& sigma2_samples = output.GetVarianceSamples(); + std::vector& sigma2_samples = output.GetGlobalVarianceSamples(); + std::vector& tau_samples = output.GetLeafVarianceSamples(); std::vector& train_preds = output.GetTrainPreds(); std::vector& test_preds = output.GetTestPreds(); @@ -128,13 +161,16 @@ class BARTDispatcher { ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_); TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf); - // Initialize variance model + // Initialize global variance model GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); + // Initialize leaf variance 
model + LeafNodeHomoskedasticVarianceModel leaf_var_model = LeafNodeHomoskedasticVarianceModel(); + // Initialize leaf model and samplers // TODO: add template specialization for GaussianMultivariateRegressionLeafModel which takes Eigen::MatrixXd& // as initialization parameter instead of double - ModelType leaf_model = ModelType(leaf_var_init); + ModelType leaf_model = ModelType(leaf_cov_init); GFRForestSampler gfr_sampler = GFRForestSampler(cutpoint_grid_size); MCMCForestSampler mcmc_sampler = MCMCForestSampler(); @@ -149,9 +185,11 @@ class BARTDispatcher { gfr_sampler.SampleOneIter(tracker, *forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, rng, variable_weights, global_var, feature_types, false); - // Sample the global outcome - global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); - sigma2_samples.at(iter) = global_var; + if (sample_global_var) { + // Sample the global outcome + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + } // Increment sample counter iter++; @@ -165,9 +203,11 @@ class BARTDispatcher { mcmc_sampler.SampleOneIter(tracker, *forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, rng, variable_weights, global_var, true); - // Sample the global outcome - global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); - sigma2_samples.at(iter) = global_var; + if (sample_global_var) { + // Sample the global outcome + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + } // Increment sample counter iter++; @@ -180,18 +220,29 @@ class BARTDispatcher { } private: - // Sampling details + // "Core" BART / XBART sampling objects + // Dimensions int num_gfr_{0}; int num_burnin_{0}; int num_mcmc_{0}; int num_train_{0}; int num_test_{0}; bool has_test_set_{false}; - - // 
Sampling data objects + // Data objects ForestDataset train_dataset_; ForestDataset test_dataset_; ColumnVector train_outcome_; + + // (Optional) random effect sampling details + // Dimensions + int num_rfx_groups_{0}; + int num_rfx_basis_{0}; + bool has_rfx_{false}; + // Data objects + RandomEffectsDataset rfx_train_dataset_; + RandomEffectsDataset rfx_test_dataset_; + RandomEffectsTracker rfx_tracker_; + MultivariateRegressionRandomEffectsModel rfx_model_; }; } // namespace StochTree diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 3ea7a8bb..e3709d98 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -66,6 +66,12 @@ class GaussianConstantSuffStat { class GaussianConstantLeafModel { public: GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} + GaussianConstantLeafModel(Eigen::MatrixXd& tau) { + CHECK_EQ(tau.rows(), 1); + CHECK_EQ(tau.cols(), 1); + tau_ = tau(0,0); + normal_sampler_ = UnivariateNormalSampler(); + } ~GaussianConstantLeafModel() {} std::tuple EvaluateProposedSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance); std::tuple EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id); @@ -132,6 +138,12 @@ class GaussianUnivariateRegressionSuffStat { class GaussianUnivariateRegressionLeafModel { public: GaussianUnivariateRegressionLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} + GaussianUnivariateRegressionLeafModel(Eigen::MatrixXd& tau) { + CHECK_EQ(tau.rows(), 1); + CHECK_EQ(tau.cols(), 1); + tau_ = tau(0,0); + normal_sampler_ = UnivariateNormalSampler(); + } ~GaussianUnivariateRegressionLeafModel() {} std::tuple EvaluateProposedSplit(ForestDataset& dataset, 
ForestTracker& tracker, ColumnVector& residual, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance); std::tuple EvaluateExistingSplit(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id); diff --git a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h index 7d7a65c0..5ece8b1a 100644 --- a/include/stochtree/random_effects.h +++ b/include/stochtree/random_effects.h @@ -32,23 +32,23 @@ namespace StochTree { class RandomEffectsTracker { public: RandomEffectsTracker(std::vector& group_indices); + RandomEffectsTracker(); ~RandomEffectsTracker() {} - inline data_size_t GetCategoryId(int observation_num) {return sample_category_mapper_->GetCategoryId(observation_num);} - inline data_size_t CategoryBegin(int category_id) {return category_sample_tracker_->CategoryBegin(category_id);} - inline data_size_t CategoryEnd(int category_id) {return category_sample_tracker_->CategoryEnd(category_id);} - inline data_size_t CategorySize(int category_id) {return category_sample_tracker_->CategorySize(category_id);} - inline int32_t NumCategories() {return num_categories_;} - inline int32_t CategoryNumber(int32_t category_id) {return category_sample_tracker_->CategoryNumber(category_id);} - SampleCategoryMapper* GetSampleCategoryMapper() {return sample_category_mapper_.get();} - CategorySampleTracker* GetCategorySampleTracker() {return category_sample_tracker_.get();} - std::vector::iterator UnsortedNodeBeginIterator(int category_id); - std::vector::iterator UnsortedNodeEndIterator(int category_id); - std::map& GetLabelMap() {return category_sample_tracker_->GetLabelMap();} - std::vector& GetUniqueGroupIds() {return category_sample_tracker_->GetUniqueGroupIds();} - std::vector& NodeIndices(int category_id) {return category_sample_tracker_->NodeIndices(category_id);} - std::vector& NodeIndicesInternalIndex(int 
internal_category_id) {return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);} - double GetPrediction(data_size_t observation_num) {return rfx_predictions_.at(observation_num);} - void SetPrediction(data_size_t observation_num, double pred) {rfx_predictions_.at(observation_num) = pred;} + void Reset(std::vector& group_indices); + inline data_size_t GetCategoryId(int observation_num) {CHECK(initialized_); return sample_category_mapper_->GetCategoryId(observation_num);} + inline data_size_t CategoryBegin(int category_id) {CHECK(initialized_); return category_sample_tracker_->CategoryBegin(category_id);} + inline data_size_t CategoryEnd(int category_id) {CHECK(initialized_); return category_sample_tracker_->CategoryEnd(category_id);} + inline data_size_t CategorySize(int category_id) {CHECK(initialized_); return category_sample_tracker_->CategorySize(category_id);} + inline int32_t NumCategories() {CHECK(initialized_); return num_categories_;} + inline int32_t CategoryNumber(int32_t category_id) {CHECK(initialized_); return category_sample_tracker_->CategoryNumber(category_id);} + SampleCategoryMapper* GetSampleCategoryMapper() {CHECK(initialized_); return sample_category_mapper_.get();} + CategorySampleTracker* GetCategorySampleTracker() {CHECK(initialized_); return category_sample_tracker_.get();} + std::map& GetLabelMap() {CHECK(initialized_); return category_sample_tracker_->GetLabelMap();} + std::vector& GetUniqueGroupIds() {CHECK(initialized_); return category_sample_tracker_->GetUniqueGroupIds();} + std::vector& NodeIndices(int category_id) {CHECK(initialized_); return category_sample_tracker_->NodeIndices(category_id);} + std::vector& NodeIndicesInternalIndex(int internal_category_id) {CHECK(initialized_); return category_sample_tracker_->NodeIndicesInternalIndex(internal_category_id);} + double GetPrediction(data_size_t observation_num) {CHECK(initialized_); return rfx_predictions_.at(observation_num);} + void 
SetPrediction(data_size_t observation_num, double pred) {CHECK(initialized_); rfx_predictions_.at(observation_num) = pred;} private: /*! \brief Mapper from observations to category indices */ @@ -60,6 +60,7 @@ class RandomEffectsTracker { /*! \brief Some high-level details of the random effects structure */ int num_categories_; int num_observations_; + bool initialized_{false}; }; /*! \brief Standalone container for the map from category IDs to 0-based indices */ @@ -100,8 +101,23 @@ class MultivariateRegressionRandomEffectsModel { group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_); group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_); working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_); + initialized_ = true; + } + MultivariateRegressionRandomEffectsModel() { + normal_sampler_ = MultivariateNormalSampler(); + ig_sampler_ = InverseGammaSampler(); + initialized_ = false; } ~MultivariateRegressionRandomEffectsModel() {} + void Reset(int num_components, int num_groups) { + num_components_ = num_components; + num_groups_ = num_groups; + working_parameter_ = Eigen::VectorXd(num_components_); + group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_); + group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_); + working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_); + initialized_ = true; + } /*! \brief Samplers */ void SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& tracker, double global_variance, std::mt19937& gen); @@ -228,6 +244,7 @@ class MultivariateRegressionRandomEffectsModel { /*! \brief Random effects structure details */ int num_components_; int num_groups_; + bool initialized_; /*! 
\brief Group mean parameters, decomposed into "working parameter" and individual parameters * under the "redundant" parameterization of Gelman et al (2008) diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 046c8927..60127f1b 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -18,7 +18,7 @@ cpp11::external_pointer run_bart_cpp( cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, - double nu, double lamb, double leaf_variance_init, double global_variance_init, + double nu, double lamb, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int ) { // Create smart pointer to newly allocated object @@ -36,6 +36,17 @@ cpp11::external_pointer run_bart_cpp( feature_types_vector[i] = static_cast(feature_types[i]); } + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + // Create BART dispatcher and add data double* covariate_data_ptr = REAL(PROTECT(covariates)); double* outcome_data_ptr = REAL(PROTECT(outcome)); @@ -47,8 +58,9 @@ cpp11::external_pointer run_bart_cpp( // Run the sampling loop bart_dispatcher.RunSampler( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, - num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, - alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + 
true, false, -1 ); } else if (leaf_model_int == 1) { // Create the dispatcher and load the data @@ -58,8 +70,9 @@ cpp11::external_pointer run_bart_cpp( // Run the sampling loop bart_dispatcher.RunSampler( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, - num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, - alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + true, false, -1 ); } // // TODO: Figure out dispatch here @@ -71,8 +84,9 @@ cpp11::external_pointer run_bart_cpp( // // Run the sampling loop // bart_dispatcher.RunSampler( // *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, - // num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, - // alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size + // num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + // alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + // true, false, -1 // ); // } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index b0f61819..73eab2cd 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -6,10 +6,10 @@ #include // R_bart.cpp -cpp11::external_pointer run_bart_cpp(cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, double nu, double lamb, double leaf_variance_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int); -extern "C" SEXP _stochtree_run_bart_cpp(SEXP covariates, SEXP outcome, SEXP feature_types, SEXP variable_weights, SEXP num_rows, 
SEXP num_covariates, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP leaf_variance_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int) { +cpp11::external_pointer run_bart_cpp(cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, double nu, double lamb, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int); +extern "C" SEXP _stochtree_run_bart_cpp(SEXP covariates, SEXP outcome, SEXP feature_types, SEXP variable_weights, SEXP num_rows, SEXP num_covariates, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int) { BEGIN_CPP11 - return cpp11::as_sexp(run_bart_cpp(cpp11::as_cpp>(covariates), cpp11::as_cpp>(outcome), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_rows), cpp11::as_cpp>(num_covariates), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(leaf_variance_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), 
cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int))); + return cpp11::as_sexp(run_bart_cpp(cpp11::as_cpp>(covariates), cpp11::as_cpp>(outcome), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_rows), cpp11::as_cpp>(num_covariates), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int))); END_CPP11 } // R_data.cpp diff --git a/src/random_effects.cpp b/src/random_effects.cpp index bc746e81..0d74363c 100644 --- a/src/random_effects.cpp +++ b/src/random_effects.cpp @@ -9,6 +9,20 @@ RandomEffectsTracker::RandomEffectsTracker(std::vector& group_indices) num_categories_ = category_sample_tracker_->NumCategories(); num_observations_ = group_indices.size(); rfx_predictions_.resize(num_observations_, 0.); + initialized_ = true; +} + +RandomEffectsTracker::RandomEffectsTracker() { + initialized_ = false; +} + +void RandomEffectsTracker::Reset(std::vector& group_indices) { + sample_category_mapper_ = std::make_unique(group_indices); + category_sample_tracker_ = std::make_unique(group_indices); + num_categories_ = category_sample_tracker_->NumCategories(); + num_observations_ = group_indices.size(); + rfx_predictions_.resize(num_observations_, 0.); + initialized_ = true; } nlohmann::json LabelMapper::to_json() { @@ -41,6 +55,7 @@ void LabelMapper::from_json(const nlohmann::json& rfx_label_mapper_json) { void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, 
double global_variance, std::mt19937& gen) { + CHECK(initialized_); // Update partial residual to add back in the random effects AddCurrentPredictionToResidual(dataset, rfx_tracker, residual); @@ -55,6 +70,7 @@ void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffects void MultivariateRegressionRandomEffectsModel::SampleWorkingParameter(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) { + CHECK(initialized_); Eigen::VectorXd posterior_mean = WorkingParameterMean(dataset, residual, rfx_tracker, global_variance); Eigen::MatrixXd posterior_covariance = WorkingParameterVariance(dataset, residual, rfx_tracker, global_variance); working_parameter_ = normal_sampler_.SampleEigen(posterior_mean, posterior_covariance, gen); @@ -62,6 +78,7 @@ void MultivariateRegressionRandomEffectsModel::SampleWorkingParameter(RandomEffe void MultivariateRegressionRandomEffectsModel::SampleGroupParameters(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) { + CHECK(initialized_); int32_t num_groups = num_groups_; Eigen::VectorXd posterior_mean; Eigen::MatrixXd posterior_covariance; @@ -75,6 +92,7 @@ void MultivariateRegressionRandomEffectsModel::SampleGroupParameters(RandomEffec void MultivariateRegressionRandomEffectsModel::SampleVarianceComponents(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, std::mt19937& gen) { + CHECK(initialized_); int32_t num_components = num_components_; double posterior_shape; double posterior_scale; @@ -88,6 +106,7 @@ void MultivariateRegressionRandomEffectsModel::SampleVarianceComponents(RandomEf Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance){ + CHECK(initialized_); 
int32_t num_components = num_components_; int32_t num_groups = num_groups_; std::vector observation_indices; @@ -111,6 +130,7 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(R } Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::WorkingParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance){ + CHECK(initialized_); int32_t num_components = num_components_; int32_t num_groups = num_groups_; std::vector observation_indices; @@ -133,6 +153,7 @@ Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::WorkingParameterVarian } Eigen::VectorXd MultivariateRegressionRandomEffectsModel::GroupParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id) { + CHECK(initialized_); int32_t num_components = num_components_; int32_t num_groups = num_groups_; Eigen::MatrixXd X = dataset.GetBasis(); @@ -149,6 +170,7 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::GroupParameterMean(Ran } Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::GroupParameterVariance(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t group_id){ + CHECK(initialized_); int32_t num_components = num_components_; int32_t num_groups = num_groups_; Eigen::MatrixXd X = dataset.GetBasis(); @@ -165,10 +187,12 @@ Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::GroupParameterVariance } double MultivariateRegressionRandomEffectsModel::VarianceComponentShape(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id) { + CHECK(initialized_); return static_cast(variance_prior_shape_ + num_groups_); } double MultivariateRegressionRandomEffectsModel::VarianceComponentScale(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& 
rfx_tracker, double global_variance, int32_t component_id) { + CHECK(initialized_); int32_t num_groups = num_groups_; Eigen::MatrixXd xi = group_parameters_; double output = variance_prior_scale_; From 75d9480c69c322e253982182fdad11b8314e965a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 18 Jul 2024 01:44:57 -0400 Subject: [PATCH 08/18] Updated C++ sampling loop to include more complete feature set --- R/bart.R | 328 +++++++-- R/cpp11.R | 32 +- include/stochtree/bart.h | 120 +++- include/stochtree/leaf_model.h | 2 + include/stochtree/random_effects.h | 13 +- man/bart_specialized.Rd | 44 +- src/R_bart.cpp | 1050 +++++++++++++++++++++++++++- src/cpp11.cpp | 64 +- src/stochtree_types.h | 1 + 9 files changed, 1529 insertions(+), 125 deletions(-) diff --git a/R/bart.R b/R/bart.R index 3d78dbb1..25fabd7d 100644 --- a/R/bart.R +++ b/R/bart.R @@ -630,9 +630,24 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL #' categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata #' that the column is ordered categorical). #' @param y_train Outcome to be modeled by the ensemble. +#' @param W_train (Optional) Bases used to define a regression model `y ~ W` in +#' each leaf of each regression tree. By default, BART assumes constant leaf node +#' parameters, implicitly regressing on a constant basis of ones (i.e. `y ~ 1`). +#' @param group_ids_train (Optional) Group labels used for an additive random effects model. +#' @param rfx_basis_train (Optional) Basis for "random-slope" regression in an additive random effects model. +#' If `group_ids_train` is provided with a regression basis, an intercept-only random effects model +#' will be estimated. #' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. #' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with #' that of `X_train`. 
+#' @param W_test (Optional) Test set of bases used to define "out of sample" evaluation data. +#' While a test set is optional, the structure of any provided test set must match that +#' of the training set (i.e. if both X_train and W_train are provided, then a test set must +#' consist of X_test and W_test with the same number of columns). +#' @param group_ids_test (Optional) Test set group labels used for an additive random effects model. +#' We do not currently support (but plan to in the near future), test set evaluation for group labels +#' that were not in the training set. +#' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model. #' @param cutpoint_grid_size Maximum size of the "grid" of potential cutpoints to consider. Default: 100. #' @param tau_init Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. #' @param alpha Prior probability of splitting for a tree of depth 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. @@ -650,10 +665,14 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. #' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. +#' @param sample_sigma Whether or not to update the `sigma^2` global error variance parameter based on `IG(nu, nu*lambda)`. Default: T. +#' @param sample_tau Whether or not to update the `tau` leaf scale variance parameter based on `IG(a_leaf, b_leaf)`. Cannot (currently) be set to true if `ncol(W_train)>1`. Default: T. #' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`. 
#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0. #' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default TRUE. Ignored if num_mcmc = 0. #' @param verbose Whether or not to print progress during the sampling loops. Default: FALSE. +#' @param sample_global_var Whether or not global variance parameter should be sampled. Default: TRUE. +#' @param sample_leaf_var Whether or not leaf model variance parameter should be sampled. Default: FALSE. #' #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). #' @export @@ -682,14 +701,17 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL #' bart_model <- bart_specialized(X_train = X_train, y_train = y_train, X_test = X_test) #' # plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") #' # abline(0,1,col="red",lty=3,lwd=3) -bart_specialized <- function( - X_train, y_train, X_test = NULL, cutpoint_grid_size = 100, - tau_init = NULL, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, - nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, - q = 0.9, sigma2_init = NULL, variable_weights = NULL, - num_trees = 200, num_gfr = 5, num_burnin = 0, num_mcmc = 100, - random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F -){ +bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, + rfx_basis_train = NULL, X_test = NULL, W_test = NULL, + group_ids_test = NULL, rfx_basis_test = NULL, + cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95, + beta = 2.0, min_samples_leaf = 5, leaf_model = 0, + nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, + q = 0.9, sigma2_init = NULL, variable_weights = NULL, + num_trees = 200, num_gfr = 5, num_burnin = 0, + num_mcmc = 100, sample_sigma = T, sample_tau = T, + 
random_seed = -1, keep_burnin = F, keep_gfr = F, + verbose = F, sample_global_var = T, sample_leaf_var = F){ # Variable weight preprocessing (and initialization if necessary) if (is.null(variable_weights)) { variable_weights = rep(1/ncol(X_train), ncol(X_train)) @@ -713,33 +735,121 @@ bart_specialized <- function( train_cov_preprocess_list <- preprocessTrainData(X_train) X_train_metadata <- train_cov_preprocess_list$metadata X_train <- train_cov_preprocess_list$data + num_rows_train <- nrow(X_train) + num_cov_train <- ncol(X_train) + num_cov_test <- num_cov_train original_var_indices <- X_train_metadata$original_var_indices feature_types <- X_train_metadata$feature_types feature_types <- as.integer(feature_types) - if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata) + if (!is.null(X_test)) { + X_test <- preprocessPredictionData(X_test, X_train_metadata) + num_rows_test <- nrow(X_test) + } else { + num_rows_test <- 0 + } # Update variable weights variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x)) variable_weights <- variable_weights[original_var_indices]*variable_weights_adj + # Convert all input data to matrices if not already converted + if ((is.null(dim(W_train))) && (!is.null(W_train))) { + W_train <- as.matrix(W_train) + } + if ((is.null(dim(W_test))) && (!is.null(W_test))) { + W_test <- as.matrix(W_test) + } + if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) { + rfx_basis_train <- as.matrix(rfx_basis_train) + } + if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { + rfx_basis_test <- as.matrix(rfx_basis_test) + } + + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) 
+ has_rfx <- F + has_rfx_test <- F + if (!is.null(group_ids_train)) { + group_ids_factor <- factor(group_ids_train) + group_ids_train <- as.integer(group_ids_factor) + has_rfx <- T + if (!is.null(group_ids_test)) { + group_ids_factor_test <- factor(group_ids_test, levels = levels(group_ids_factor)) + if (sum(is.na(group_ids_factor_test)) > 0) { + stop("All random effect group labels provided in group_ids_test must be present in group_ids_train") + } + group_ids_test <- as.integer(group_ids_factor_test) + has_rfx_test <- T + } + } + # Data consistency checks if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { stop("X_train and X_test must have the same number of columns") } + if ((!is.null(W_test)) && (ncol(W_test) != ncol(W_train))) { + stop("W_train and W_test must have the same number of columns") + } + if ((!is.null(W_train)) && (nrow(W_train) != nrow(X_train))) { + stop("W_train and X_train must have the same number of rows") + } + if ((!is.null(W_test)) && (nrow(W_test) != nrow(X_test))) { + stop("W_test and X_test must have the same number of rows") + } if (nrow(X_train) != length(y_train)) { stop("X_train and y_train must have the same number of observations") } - + if ((!is.null(rfx_basis_test)) && (ncol(rfx_basis_test) != ncol(rfx_basis_train))) { + stop("rfx_basis_train and rfx_basis_test must have the same number of columns") + } + if (!is.null(group_ids_train)) { + if (!is.null(group_ids_test)) { + if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { + stop("rfx_basis_train is provided but rfx_basis_test is not provided") + } + } + } + + # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + has_basis_rfx <- F + num_basis_rfx <- 0 + num_rfx_groups <- 0 + if (has_rfx) { + if (is.null(rfx_basis_train)) { + rfx_basis_train <- matrix(rep(1,nrow(X_train)), nrow = nrow(X_train), ncol = 1) + num_basis_rfx <- 1 + } else { + has_basis_rfx <- T + num_basis_rfx <- ncol(rfx_basis_train) + } + num_rfx_groups <- 
length(unique(group_ids_train)) + num_rfx_components <- ncol(rfx_basis_train) + if (num_rfx_groups == 1) warning("Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill") + } + if (has_rfx_test) { + if (is.null(rfx_basis_test)) { + if (!is.null(rfx_basis_train)) { + stop("Random effects basis provided for training set, must also be provided for the test set") + } + rfx_basis_test <- matrix(rep(1,nrow(X_test)), nrow = nrow(X_test), ncol = 1) + } + } + # Convert y_train to numeric vector if not already converted if (!is.null(dim(y_train))) { y_train <- as.matrix(y_train) } # Determine whether a basis vector is provided - has_basis = F + has_basis = !is.null(W_train) + if (has_basis) num_basis_train <- ncol(W_train) + else num_basis_train <- 0 + num_basis_test <- num_basis_train # Determine whether a test set is provided has_test = !is.null(X_test) + if (has_test) num_test <- nrow(X_test) + else num_test <- 0 # Standardize outcome separately for test and train y_bar_train <- mean(y_train) @@ -747,7 +857,7 @@ bart_specialized <- function( resid_train <- (y_train-y_bar_train)/y_std_train # Calibrate priors for sigma^2 and tau - reg_basis <- X_train + reg_basis <- cbind(W_train, X_train) sigma2hat <- (sigma(lm(resid_train~reg_basis)))^2 quantile_cutoff <- 0.9 if (is.null(lambda)) { @@ -755,72 +865,148 @@ bart_specialized <- function( } if (is.null(sigma2_init)) sigma2_init <- sigma2hat if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees) - if (is.null(tau_init)) tau_init <- as.matrix(var(resid_train)/(num_trees)) - current_leaf_scale <- tau_init + if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees) + current_leaf_scale <- as.matrix(tau_init) current_sigma2 <- sigma2_init # Determine leaf model type - leaf_model <- 0 + if (!has_basis) leaf_model <- 0 + else if (ncol(W_train) == 1) leaf_model <- 1 + else if (ncol(W_train) > 1) leaf_model <- 2 + else stop("W_train passed must be a matrix with at 
least 1 column") # Unpack model type info - output_dimension = 1 - is_leaf_constant = T - leaf_regression = F - - # Container of variance parameter samples - num_samples <- num_gfr + num_burnin + num_mcmc + if (leaf_model == 0) { + output_dimension = 1 + is_leaf_constant = T + leaf_regression = F + } else if (leaf_model == 1) { + stopifnot(has_basis) + stopifnot(ncol(W_train) == 1) + output_dimension = 1 + is_leaf_constant = F + leaf_regression = T + } else if (leaf_model == 2) { + stopifnot(has_basis) + stopifnot(ncol(W_train) > 1) + output_dimension = ncol(W_train) + is_leaf_constant = F + leaf_regression = T + if (sample_tau) { + stop("Sampling leaf scale not yet supported for multivariate leaf models") + } + } + + # Random effects prior parameters + alpha_init <- as.numeric(NULL) + xi_init <- as.numeric(NULL) + sigma_alpha_init <- as.numeric(NULL) + sigma_xi_init <- as.numeric(NULL) + sigma_xi_shape <- NULL + sigma_xi_scale <- NULL + if (has_rfx) { + if (num_rfx_components == 1) { + alpha_init <- c(1) + } else if (num_rfx_components > 1) { + alpha_init <- c(1,rep(0,num_rfx_components-1)) + } else { + stop("There must be at least 1 random effect component") + } + xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups) + sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components) + sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components) + sigma_xi_shape <- 1 + sigma_xi_scale <- 1 + } # Run the BART sampler - bart_result_ptr <- run_bart_cpp( - as.numeric(X_train), y_train, feature_types, variable_weights, nrow(X_train), - ncol(X_train), num_trees, output_dimension, is_leaf_constant, alpha, beta, - min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lambda, - tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, 0 - ) -# -# # Forest predictions -# y_hat_train <- forest_samples$predict(forest_dataset_train)*y_std_train + y_bar_train -# if (has_test) y_hat_test <- 
forest_samples$predict(forest_dataset_test)*y_std_train + y_bar_train -# -# # Random effects predictions -# if (has_rfx) { -# rfx_preds_train <- rfx_samples$predict(group_ids_train, rfx_basis_train)*y_std_train -# y_hat_train <- y_hat_train + rfx_preds_train -# } -# if ((has_rfx_test) && (has_test)) { -# rfx_preds_test <- rfx_samples$predict(group_ids_test, rfx_basis_test)*y_std_train -# y_hat_test <- y_hat_test + rfx_preds_test -# } - - # # Compute retention indices - # if (num_mcmc > 0) { - # keep_indices = mcmc_indices - # if (keep_gfr) keep_indices <- c(gfr_indices, keep_indices) - # if (keep_burnin) keep_indices <- c(burnin_indices, keep_indices) - # } else { - # if ((num_gfr > 0) && (num_burnin > 0)) { - # # Override keep_gfr = FALSE since there are no MCMC samples - # # Don't retain both GFR and burnin samples - # keep_indices = gfr_indices - # } else if ((num_gfr <= 0) && (num_burnin > 0)) { - # # Override keep_burnin = FALSE since there are no MCMC or GFR samples - # keep_indices = burnin_indices - # } else if ((num_gfr > 0) && (num_burnin <= 0)) { - # # Override keep_gfr = FALSE since there are no MCMC samples - # keep_indices = gfr_indices - # } else { - # stop("There are no samples to retain!") - # } - # } - # - # # Subset forest and RFX predictions - # y_hat_train <- y_hat_train[,keep_indices] - # if (has_test) { - # y_hat_test <- y_hat_test[,keep_indices] - # } - # - # # Global error variance - # if (sample_sigma) sigma2_samples <- global_var_samples[keep_indices]*(y_std_train^2) + if ((has_basis) && (has_test) && (has_rfx)) { + bart_result_ptr <- run_bart_cpp_basis_test_rfx( + as.numeric(X_train), as.numeric(W_train), resid_train, + num_rows_train, num_cov_train, num_basis_train, + as.numeric(X_test), as.numeric(W_test), num_rows_test, num_cov_test, num_basis_test, + as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, + as.numeric(rfx_basis_test), group_ids_test, num_basis_rfx, num_rfx_groups, + feature_types, 
variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale + ) + } else if ((has_basis) && (has_test) && (!has_rfx)) { + bart_result_ptr <- run_bart_cpp_basis_test_norfx( + as.numeric(X_train), as.numeric(W_train), resid_train, + num_rows_train, num_cov_train, num_basis_train, + as.numeric(X_test), as.numeric(W_test), num_rows_test, num_cov_test, num_basis_test, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var + ) + } else if ((has_basis) && (!has_test) && (has_rfx)) { + bart_result_ptr <- run_bart_cpp_basis_notest_rfx( + as.numeric(X_train), as.numeric(W_train), resid_train, + num_rows_train, num_cov_train, num_basis_train, + as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale + ) + } else if ((has_basis) && (!has_test) && (!has_rfx)) { + bart_result_ptr <- run_bart_cpp_basis_notest_norfx( + as.numeric(X_train), as.numeric(W_train), resid_train, + num_rows_train, num_cov_train, num_basis_train, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + 
tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var + ) + } else if ((!has_basis) && (has_test) && (has_rfx)) { + bart_result_ptr <- run_bart_cpp_nobasis_test_rfx( + as.numeric(X_train), resid_train, + num_rows_train, num_cov_train, + as.numeric(X_test), num_rows_test, num_cov_test, + as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, + as.numeric(rfx_basis_test), group_ids_test, num_basis_rfx, num_rfx_groups, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale + ) + } else if ((!has_basis) && (has_test) && (!has_rfx)) { + bart_result_ptr <- run_bart_cpp_nobasis_test_norfx( + as.numeric(X_train), resid_train, + num_rows_train, num_cov_train, + as.numeric(X_test), num_rows_test, num_cov_test, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var + ) + } else if ((!has_basis) && (!has_test) && (has_rfx)) { + bart_result_ptr <- run_bart_cpp_nobasis_notest_rfx( + as.numeric(X_train), resid_train, + num_rows_train, num_cov_train, + as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, 
sigma_xi_shape, sigma_xi_scale + ) + } else if ((!has_basis) && (!has_test) && (!has_rfx)) { + bart_result_ptr <- run_bart_cpp_nobasis_notest_norfx( + as.numeric(X_train), resid_train, + num_rows_train, num_cov_train, + feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, + sample_global_var, sample_leaf_var + ) + } # Return results as a list model_params <- list( diff --git a/R/cpp11.R b/R/cpp11.R index dba48240..1a4e6ac7 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -1,7 +1,35 @@ # Generated by cpp11: do not edit by hand -run_bart_cpp <- function(covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int) { - .Call(`_stochtree_run_bart_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int) +run_bart_cpp_basis_test_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, 
random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_basis_test_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +} + +run_bart_cpp_basis_test_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_basis_test_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, 
sample_global_var, sample_leaf_var) +} + +run_bart_cpp_basis_notest_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_basis_notest_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +} + +run_bart_cpp_basis_notest_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_basis_notest_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, 
a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) +} + +run_bart_cpp_nobasis_test_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_nobasis_test_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +} + +run_bart_cpp_nobasis_test_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, 
cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_nobasis_test_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) +} + +run_bart_cpp_nobasis_notest_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_nobasis_notest_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +} + +run_bart_cpp_nobasis_notest_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, 
num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_nobasis_notest_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) } create_forest_dataset_cpp <- function() { diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index a19d4ea0..b6364744 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -28,8 +28,12 @@ class BARTResult { RandomEffectsContainer* ReleaseRFXContainer() {return rfx_container_.release();} LabelMapper* GetRFXLabelMapper() {return rfx_label_mapper_.get();} LabelMapper* ReleaseRFXLabelMapper() {return rfx_label_mapper_.release();} - std::vector& GetTrainPreds() {return outcome_preds_train_;} - std::vector& GetTestPreds() {return outcome_preds_test_;} + std::vector& GetOutcomeTrainPreds() {return outcome_preds_train_;} + std::vector& GetOutcomeTestPreds() {return outcome_preds_test_;} + std::vector& GetRFXTrainPreds() {return rfx_preds_train_;} + std::vector& GetRFXTestPreds() {return rfx_preds_test_;} + std::vector& GetForestTrainPreds() {return forest_preds_train_;} + std::vector& GetForestTestPreds() {return forest_preds_test_;} std::vector& GetGlobalVarianceSamples() {return sigma_samples_;} std::vector& GetLeafVarianceSamples() {return tau_samples_;} int NumGFRSamples() {return num_gfr_;} @@ -47,6 +51,10 @@ class BARTResult { std::unique_ptr rfx_label_mapper_; std::vector outcome_preds_train_; std::vector outcome_preds_test_; + std::vector rfx_preds_train_; + 
std::vector rfx_preds_test_; + std::vector forest_preds_train_; + std::vector forest_preds_test_; std::vector sigma_samples_; std::vector tau_samples_; int num_gfr_{0}; @@ -98,7 +106,11 @@ class BARTDispatcher { } } - void AddRFXTerm(double* rfx_basis, std::vector& rfx_group_indices, data_size_t num_row, int num_groups, int num_basis, bool is_row_major, bool train) { + void AddRFXTerm(double* rfx_basis, std::vector& rfx_group_indices, data_size_t num_row, + int num_groups, int num_basis, bool is_row_major, bool train, + Eigen::VectorXd& alpha_init, Eigen::MatrixXd& xi_init, + Eigen::MatrixXd& sigma_alpha_init, Eigen::MatrixXd& sigma_xi_init, + double sigma_xi_shape, double sigma_xi_scale) { if (train) { rfx_train_dataset_ = RandomEffectsDataset(); rfx_train_dataset_.AddBasis(rfx_basis, num_row, num_basis, is_row_major); @@ -108,6 +120,12 @@ class BARTDispatcher { num_rfx_groups_ = num_groups; num_rfx_basis_ = num_basis; has_rfx_ = true; + rfx_model_.SetWorkingParameter(alpha_init); + rfx_model_.SetGroupParameters(xi_init); + rfx_model_.SetWorkingParameterCovariance(sigma_alpha_init); + rfx_model_.SetGroupParameterCovariance(sigma_xi_init); + rfx_model_.SetVariancePriorShape(sigma_xi_shape); + rfx_model_.SetVariancePriorScale(sigma_xi_scale); } else { rfx_test_dataset_ = RandomEffectsDataset(); rfx_test_dataset_.AddBasis(rfx_basis, num_row, num_basis, is_row_major); @@ -144,18 +162,50 @@ class BARTDispatcher { // Obtain references to forest / parameter samples and predictions in BARTResult ForestContainer* forest_samples = output.GetForests(); + RandomEffectsContainer* rfx_container = output.GetRFXContainer(); + LabelMapper* label_mapper = output.GetRFXLabelMapper(); std::vector& sigma2_samples = output.GetGlobalVarianceSamples(); std::vector& tau_samples = output.GetLeafVarianceSamples(); - std::vector& train_preds = output.GetTrainPreds(); - std::vector& test_preds = output.GetTestPreds(); + std::vector& forest_train_preds = output.GetForestTrainPreds(); + 
std::vector& forest_test_preds = output.GetForestTestPreds(); + std::vector& rfx_train_preds = output.GetRFXTrainPreds(); + std::vector& rfx_test_preds = output.GetRFXTestPreds(); + std::vector& outcome_train_preds = output.GetOutcomeTrainPreds(); + std::vector& outcome_test_preds = output.GetOutcomeTestPreds(); + + // Update RFX output containers + if (has_rfx_) { + rfx_container->Initialize(num_rfx_basis_, num_rfx_groups_); + label_mapper->Initialize(rfx_tracker_.GetLabelMap()); + } // Clear and prepare vectors to store results - sigma2_samples.clear(); - train_preds.clear(); - test_preds.clear(); - sigma2_samples.resize(num_samples); - train_preds.resize(num_samples*num_train_); - if (has_test_set_) test_preds.resize(num_samples*num_test_); + forest_train_preds.clear(); + forest_train_preds.resize(num_samples*num_train_); + outcome_train_preds.clear(); + outcome_train_preds.resize(num_samples*num_train_); + if (has_test_set_) { + forest_test_preds.clear(); + forest_test_preds.resize(num_samples*num_test_); + outcome_test_preds.clear(); + outcome_test_preds.resize(num_samples*num_test_); + } + if (sample_global_var) { + sigma2_samples.clear(); + sigma2_samples.resize(num_samples); + } + if (sample_leaf_var) { + tau_samples.clear(); + tau_samples.resize(num_samples); + } + if (has_rfx_) { + rfx_train_preds.clear(); + rfx_train_preds.resize(num_samples*num_train_); + if (has_test_set_) { + rfx_test_preds.clear(); + rfx_test_preds.resize(num_samples*num_test_); + } + } // Initialize tracker and tree prior ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_); @@ -166,6 +216,12 @@ class BARTDispatcher { // Initialize leaf variance model LeafNodeHomoskedasticVarianceModel leaf_var_model = LeafNodeHomoskedasticVarianceModel(); + double leaf_var; + if (sample_leaf_var) { + CHECK_EQ(leaf_cov_init.rows(),1); + CHECK_EQ(leaf_cov_init.cols(),1); + leaf_var = leaf_cov_init(0,0); + } // Initialize leaf model and samplers 
// TODO: add template specialization for GaussianMultivariateRegressionLeafModel which takes Eigen::MatrixXd& @@ -176,6 +232,7 @@ class BARTDispatcher { // Running variable for current sampled value of global outcome variance parameter double global_var = global_var_init; + Eigen::MatrixXd leaf_cov = leaf_cov_init; // Run the XBART Gibbs sampler int iter = 0; @@ -186,10 +243,18 @@ class BARTDispatcher { rng, variable_weights, global_var, feature_types, false); if (sample_global_var) { - // Sample the global outcome + // Sample the global outcome variance global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); sigma2_samples.at(iter) = global_var; } + + if (sample_leaf_var) { + // Sample the leaf model variance + TreeEnsemble* ensemble = forest_samples->GetEnsemble(iter); + leaf_var = leaf_var_model.SampleVarianceParameter(ensemble, a_leaf, b_leaf, rng); + tau_samples.at(iter) = leaf_var; + leaf_cov(0,0) = leaf_var; + } // Increment sample counter iter++; @@ -204,19 +269,42 @@ class BARTDispatcher { rng, variable_weights, global_var, true); if (sample_global_var) { - // Sample the global outcome + // Sample the global outcome variance global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); sigma2_samples.at(iter) = global_var; } + + if (sample_leaf_var) { + // Sample the leaf model variance + TreeEnsemble* ensemble = forest_samples->GetEnsemble(iter); + leaf_var = leaf_var_model.SampleVarianceParameter(ensemble, a_leaf, b_leaf, rng); + tau_samples.at(iter) = leaf_var; + leaf_cov(0,0) = leaf_var; + } // Increment sample counter iter++; } } - // Predict forests - forest_samples->PredictInPlace(train_dataset_, train_preds); - if (has_test_set_) forest_samples->PredictInPlace(test_dataset_, test_preds); + // Predict forests and rfx + forest_samples->PredictInPlace(train_dataset_, forest_train_preds); + if (has_test_set_) forest_samples->PredictInPlace(test_dataset_, forest_test_preds); + if 
(has_rfx_) { + rfx_container->Predict(rfx_train_dataset_, *label_mapper, rfx_train_preds); + for (data_size_t ind = 0; ind < rfx_train_preds.size(); ind++) { + outcome_train_preds.at(ind) = rfx_train_preds.at(ind) + forest_train_preds.at(ind); + } + if (has_test_set_) { + rfx_container->Predict(rfx_test_dataset_, *label_mapper, rfx_test_preds); + for (data_size_t ind = 0; ind < rfx_test_preds.size(); ind++) { + outcome_test_preds.at(ind) = rfx_test_preds.at(ind) + forest_test_preds.at(ind); + } + } + } else { + forest_samples->PredictInPlace(train_dataset_, outcome_train_preds); + if (has_test_set_) forest_samples->PredictInPlace(test_dataset_, outcome_test_preds); + } } private: diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index e3709d98..9ae2721c 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -86,6 +86,7 @@ class GaussianConstantLeafModel { void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); void SetScale(double tau) {tau_ = tau;} + void SetScale(Eigen::MatrixXd& tau) {tau_ = tau(0,0);} inline bool RequiresBasis() {return false;} private: double tau_; @@ -158,6 +159,7 @@ class GaussianUnivariateRegressionLeafModel { void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); void SetScale(double tau) {tau_ = tau;} + void SetScale(Eigen::MatrixXd& tau) {tau_ = tau(0,0);} inline bool RequiresBasis() {return true;} private: double tau_; diff --git a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h index 5ece8b1a..3c0e5b76 100644 --- 
a/include/stochtree/random_effects.h +++ b/include/stochtree/random_effects.h @@ -67,11 +67,17 @@ class RandomEffectsTracker { class LabelMapper { public: LabelMapper() {} - LabelMapper(std::map label_map) { + LabelMapper(std::map& label_map) { label_map_ = label_map; for (const auto& [key, value] : label_map) keys_.push_back(key); } ~LabelMapper() {} + void Initialize(std::map& label_map) { + label_map_.clear(); + keys_.clear(); + label_map_ = label_map; + for (const auto& [key, value] : label_map) keys_.push_back(key); + } bool ContainsLabel(int32_t category_id) { auto pos = label_map_.find(category_id); return pos != label_map_.end(); @@ -276,6 +282,11 @@ class RandomEffectsContainer { num_samples_ = 0; } ~RandomEffectsContainer() {} + void Initialize(int num_components, int num_groups) { + num_components_ = num_components; + num_groups_ = num_groups; + num_samples_ = 0; + } void AddSample(MultivariateRegressionRandomEffectsModel& model); void Predict(RandomEffectsDataset& dataset, LabelMapper& label_mapper, std::vector& output); int NumSamples() {return num_samples_;} diff --git a/man/bart_specialized.Rd b/man/bart_specialized.Rd index d270d650..ed44bf1f 100644 --- a/man/bart_specialized.Rd +++ b/man/bart_specialized.Rd @@ -7,12 +7,19 @@ bart_specialized( X_train, y_train, + W_train = NULL, + group_ids_train = NULL, + rfx_basis_train = NULL, X_test = NULL, + W_test = NULL, + group_ids_test = NULL, + rfx_basis_test = NULL, cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95, beta = 2, min_samples_leaf = 5, + leaf_model = 0, nu = 3, lambda = NULL, a_leaf = 3, @@ -24,10 +31,14 @@ bart_specialized( num_gfr = 5, num_burnin = 0, num_mcmc = 100, + sample_sigma = T, + sample_tau = T, random_seed = -1, keep_burnin = F, keep_gfr = F, - verbose = F + verbose = F, + sample_global_var = T, + sample_leaf_var = F ) } \arguments{ @@ -39,10 +50,31 @@ that the column is ordered categorical).} \item{y_train}{Outcome to be modeled by the ensemble.} +\item{W_train}{(Optional) 
Bases used to define a regression model \code{y ~ W} in +each leaf of each regression tree. By default, BART assumes constant leaf node +parameters, implicitly regressing on a constant basis of ones (i.e. \code{y ~ 1}).} + +\item{group_ids_train}{(Optional) Group labels used for an additive random effects model.} + +\item{rfx_basis_train}{(Optional) Basis for "random-slope" regression in an additive random effects model. +If \code{group_ids_train} is provided with a regression basis, an intercept-only random effects model +will be estimated.} + \item{X_test}{(Optional) Test set of covariates used to define "out of sample" evaluation data. May be provided either as a dataframe or a matrix, but the format of \code{X_test} must be consistent with that of \code{X_train}.} +\item{W_test}{(Optional) Test set of bases used to define "out of sample" evaluation data. +While a test set is optional, the structure of any provided test set must match that +of the training set (i.e. if both X_train and W_train are provided, then a test set must +consist of X_test and W_test with the same number of columns).} + +\item{group_ids_test}{(Optional) Test set group labels used for an additive random effects model. +We do not currently support (but plan to in the near future), test set evaluation for group labels +that were not in the training set.} + +\item{rfx_basis_test}{(Optional) Test set basis for "random-slope" regression in additive random effects model.} + \item{cutpoint_grid_size}{Maximum size of the "grid" of potential cutpoints to consider. Default: 100.} \item{tau_init}{Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here.} @@ -53,6 +85,8 @@ that of \code{X_train}.} \item{min_samples_leaf}{Minimum allowable size of a leaf, in terms of training samples. Default: 5.} +\item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). 
Default: 0.} + \item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.} \item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).} @@ -75,6 +109,10 @@ that of \code{X_train}.} \item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.} +\item{sample_sigma}{Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(nu, nu*lambda)}. Default: T.} + +\item{sample_tau}{Whether or not to update the \code{tau} leaf scale variance parameter based on \code{IG(a_leaf, b_leaf)}. Cannot (currently) be set to true if \code{ncol(W_train)>1}. Default: T.} + \item{random_seed}{Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}.} \item{keep_burnin}{Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.} @@ -83,7 +121,9 @@ that of \code{X_train}.} \item{verbose}{Whether or not to print progress during the sampling loops. Default: FALSE.} -\item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.} +\item{sample_global_var}{Whether or not global variance parameter should be sampled. Default: TRUE.} + +\item{sample_leaf_var}{Whether or not leaf model variance parameter should be sampled. Default: FALSE.} } \value{ List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). 
diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 60127f1b..6c7e1ab2 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -13,13 +14,23 @@ #include [[cpp11::register]] -cpp11::external_pointer run_bart_cpp( - cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, - cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, - int output_dimension, bool is_leaf_constant, double alpha, double beta, - int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, - double nu, double lamb, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, - int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int +cpp11::external_pointer run_bart_cpp_basis_test_rfx( + cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, int num_basis_train, + cpp11::doubles covariates_test, cpp11::doubles basis_test, + int num_rows_test, int num_covariates_test, int num_basis_test, + cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, + int num_rfx_basis_train, int num_rfx_groups_train, + cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, + int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, + cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var, + cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, + cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, + double 
rfx_sigma_xi_shape, double rfx_sigma_xi_scale ) { // Create smart pointer to newly allocated object std::unique_ptr bart_result_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); @@ -47,49 +58,1030 @@ cpp11::external_pointer run_bart_cpp( } } + // Check inputs + if (num_covariates_train != num_covariates_test) { + StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test"); + } + if (num_basis_train != num_basis_test) { + StochTree::Log::Fatal("num_basis_train must equal num_basis_test"); + } + if (num_rfx_basis_train != num_rfx_basis_test) { + StochTree::Log::Fatal("num_rfx_basis_train must equal num_rfx_basis_test"); + } + if (num_rfx_groups_train != num_rfx_groups_test) { + StochTree::Log::Fatal("num_rfx_groups_train must equal num_rfx_groups_test"); + } + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Convert rfx group IDs to std::vector + std::vector rfx_group_labels_train_cpp; + std::vector rfx_group_labels_test_cpp; + rfx_group_labels_train_cpp.resize(rfx_group_labels_train.size()); + for (int i = 0; i < rfx_group_labels_train.size(); i++) { + rfx_group_labels_train_cpp.at(i) = rfx_group_labels_train.at(i); + } + rfx_group_labels_test_cpp.resize(rfx_group_labels_test.size()); + for (int i = 0; i < rfx_group_labels_test.size(); i++) { + rfx_group_labels_test_cpp.at(i) = rfx_group_labels_test.at(i); + } + + // Unpack RFX terms + Eigen::VectorXd alpha_init; + Eigen::MatrixXd xi_init; + Eigen::MatrixXd sigma_alpha_init; + Eigen::MatrixXd sigma_xi_init; + double sigma_xi_shape; + double sigma_xi_scale; + alpha_init.resize(rfx_alpha_init.size()); + xi_init.resize(rfx_xi_init.nrow(), rfx_xi_init.ncol()); + sigma_alpha_init.resize(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol()); + sigma_xi_init.resize(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol()); + for (int i = 0; i < rfx_alpha_init.size(); i++) { + alpha_init(i) = 
rfx_alpha_init.at(i); + } + for (int i = 0; i < rfx_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_xi_init.ncol(); j++) { + xi_init(i,j) = rfx_xi_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) { + sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) { + sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j); + } + } + sigma_xi_shape = rfx_sigma_xi_shape; + sigma_xi_scale = rfx_sigma_xi_scale; + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_basis_data_ptr = REAL(PROTECT(basis_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* test_covariate_data_ptr = REAL(PROTECT(covariates_test)); + double* test_basis_data_ptr = REAL(PROTECT(basis_test)); + double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train)); + double* test_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_test)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + 
num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + 
bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } + + // Unprotect pointers to R data + UNPROTECT(7); + + // Release management of the pointer to R session + return cpp11::external_pointer(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer run_bart_cpp_basis_test_norfx( + cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, int num_basis_train, + cpp11::doubles covariates_test, cpp11::doubles basis_test, + int num_rows_test, int num_covariates_test, int num_basis_test, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int 
min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var +) { + // Create smart pointer to newly allocated object + std::unique_ptr bart_result_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + if (num_covariates_train != num_covariates_test) { + StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test"); + } + if (num_basis_train != num_basis_test) { + StochTree::Log::Fatal("num_basis_train must equal num_basis_test"); + } + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + // Create BART dispatcher and add data - double* covariate_data_ptr = REAL(PROTECT(covariates)); - double* outcome_data_ptr = REAL(PROTECT(outcome)); + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_basis_data_ptr = REAL(PROTECT(basis_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* test_covariate_data_ptr = 
REAL(PROTECT(covariates_test)); + double* test_basis_data_ptr = REAL(PROTECT(basis_test)); if (leaf_model_int == 0) { // Create the dispatcher and load the data StochTree::BARTDispatcher bart_dispatcher{}; - bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true); - bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows); + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); // Run the sampling loop bart_dispatcher.RunSampler( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - true, false, -1 + sample_global_var, sample_leaf_var, random_seed ); } else if (leaf_model_int == 1) { // Create the dispatcher and load the data StochTree::BARTDispatcher bart_dispatcher{}; - bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true); - bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows); + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, 
beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, test_basis_data_ptr, num_rows_test, num_covariates_test, num_basis_test, false, false); // Run the sampling loop bart_dispatcher.RunSampler( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, - alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - true, false, -1 + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed ); } - // // TODO: Figure out dispatch here - // else { - // // Create the dispatcher and load the data - // StochTree::BARTDispatcher bart_dispatcher{}; - // bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true); - // bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows); - // // Run the sampling loop - // bart_dispatcher.RunSampler( - // *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, - // num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, - // alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - // true, false, -1 - // ); + + // Unprotect pointers to R data + UNPROTECT(5); + + // Release management of the pointer to R session + return cpp11::external_pointer(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer run_bart_cpp_basis_notest_rfx( + cpp11::doubles covariates_train, cpp11::doubles 
basis_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, int num_basis_train, + cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, + int num_rfx_basis_train, int num_rfx_groups_train, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var, + cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, + cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, + double rfx_sigma_xi_shape, double rfx_sigma_xi_scale +) { + // Create smart pointer to newly allocated object + std::unique_ptr bart_result_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf 
regression"); // } + // Convert rfx group IDs to std::vector + std::vector rfx_group_labels_train_cpp; + rfx_group_labels_train_cpp.resize(rfx_group_labels_train.size()); + for (int i = 0; i < rfx_group_labels_train.size(); i++) { + rfx_group_labels_train_cpp.at(i) = rfx_group_labels_train.at(i); + } + + // Unpack RFX terms + Eigen::VectorXd alpha_init; + Eigen::MatrixXd xi_init; + Eigen::MatrixXd sigma_alpha_init; + Eigen::MatrixXd sigma_xi_init; + double sigma_xi_shape; + double sigma_xi_scale; + alpha_init.resize(rfx_alpha_init.size()); + xi_init.resize(rfx_xi_init.nrow(), rfx_xi_init.ncol()); + sigma_alpha_init.resize(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol()); + sigma_xi_init.resize(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol()); + for (int i = 0; i < rfx_alpha_init.size(); i++) { + alpha_init(i) = rfx_alpha_init.at(i); + } + for (int i = 0; i < rfx_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_xi_init.ncol(); j++) { + xi_init(i,j) = rfx_xi_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) { + sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) { + sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j); + } + } + sigma_xi_shape = rfx_sigma_xi_shape; + sigma_xi_scale = rfx_sigma_xi_scale; + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_basis_data_ptr = REAL(PROTECT(basis_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, 
num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + 
bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } + + // Unprotect pointers to R data + UNPROTECT(4); + + // Release management of the pointer to R session + return cpp11::external_pointer(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer run_bart_cpp_basis_notest_norfx( + cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, int num_basis_train, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var +) { + // Create smart pointer to newly allocated object + std::unique_ptr bart_result_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector 
feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_basis_data_ptr = REAL(PROTECT(basis_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // 
Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, train_basis_data_ptr, num_rows_train, num_covariates_train, num_basis_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } + + // Unprotect pointers to R data + UNPROTECT(3); + + // Release management of the pointer to R session + return cpp11::external_pointer(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer run_bart_cpp_nobasis_test_rfx( + cpp11::doubles covariates_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, + cpp11::doubles covariates_test, + int num_rows_test, int num_covariates_test, + cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, + int num_rfx_basis_train, int num_rfx_groups_train, + cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, + int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, + cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> 
leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var, + cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, + cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, + double rfx_sigma_xi_shape, double rfx_sigma_xi_scale +) { + // Create smart pointer to newly allocated object + std::unique_ptr bart_result_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + if (num_covariates_train != num_covariates_test) { + StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test"); + } + if (num_rfx_basis_train != num_rfx_basis_test) { + StochTree::Log::Fatal("num_rfx_basis_train must equal num_rfx_basis_test"); + } + if (num_rfx_groups_train != num_rfx_groups_test) { + StochTree::Log::Fatal("num_rfx_groups_train must equal num_rfx_groups_test"); + } + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Convert rfx group IDs to std::vector + std::vector rfx_group_labels_train_cpp; + 
std::vector rfx_group_labels_test_cpp; + rfx_group_labels_train_cpp.resize(rfx_group_labels_train.size()); + for (int i = 0; i < rfx_group_labels_train.size(); i++) { + rfx_group_labels_train_cpp.at(i) = rfx_group_labels_train.at(i); + } + rfx_group_labels_test_cpp.resize(rfx_group_labels_test.size()); + for (int i = 0; i < rfx_group_labels_test.size(); i++) { + rfx_group_labels_test_cpp.at(i) = rfx_group_labels_test.at(i); + } + + // Unpack RFX terms + Eigen::VectorXd alpha_init; + Eigen::MatrixXd xi_init; + Eigen::MatrixXd sigma_alpha_init; + Eigen::MatrixXd sigma_xi_init; + double sigma_xi_shape; + double sigma_xi_scale; + alpha_init.resize(rfx_alpha_init.size()); + xi_init.resize(rfx_xi_init.nrow(), rfx_xi_init.ncol()); + sigma_alpha_init.resize(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol()); + sigma_xi_init.resize(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol()); + for (int i = 0; i < rfx_alpha_init.size(); i++) { + alpha_init(i) = rfx_alpha_init.at(i); + } + for (int i = 0; i < rfx_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_xi_init.ncol(); j++) { + xi_init(i,j) = rfx_xi_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_alpha_init.ncol(); j++) { + sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) { + sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j); + } + } + sigma_xi_shape = rfx_sigma_xi_shape; + sigma_xi_scale = rfx_sigma_xi_scale; + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* test_covariate_data_ptr = REAL(PROTECT(covariates_test)); + double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train)); + double* test_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_test)); + if (leaf_model_int == 0) { + // Create the dispatcher and load 
the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + 
bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + bart_dispatcher.AddRFXTerm(test_rfx_basis_data_ptr, rfx_group_labels_test_cpp, num_rows_test, + num_rfx_groups_test, num_rfx_basis_test, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } + + // Unprotect pointers to R data + UNPROTECT(5); + + // Release management of the pointer to R session + return 
cpp11::external_pointer(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer run_bart_cpp_nobasis_test_norfx( + cpp11::doubles covariates_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, + cpp11::doubles covariates_test, + int num_rows_test, int num_covariates_test, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var +) { + // Create smart pointer to newly allocated object + std::unique_ptr bart_result_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + if (num_covariates_train != num_covariates_test) { + StochTree::Log::Fatal("num_covariates_train must equal num_covariates_test"); + } + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide 
basis for leaf regression"); + // } + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* test_covariate_data_ptr = REAL(PROTECT(covariates_test)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + 
bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load test data + bart_dispatcher.AddDataset(test_covariate_data_ptr, num_rows_test, num_covariates_test, false, false); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } + + // Unprotect pointers to R data + UNPROTECT(3); + + // Release management of the pointer to R session + return cpp11::external_pointer(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer run_bart_cpp_nobasis_notest_rfx( + cpp11::doubles covariates_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, + cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, + int num_rfx_basis_train, int num_rfx_groups_train, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var, + cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, + cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, + double rfx_sigma_xi_shape, double rfx_sigma_xi_scale +) { + // Create smart pointer to newly allocated object + std::unique_ptr bart_result_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); + + // Convert variable 
weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Convert rfx group IDs to std::vector + std::vector rfx_group_labels_train_cpp; + rfx_group_labels_train_cpp.resize(rfx_group_labels_train.size()); + for (int i = 0; i < rfx_group_labels_train.size(); i++) { + rfx_group_labels_train_cpp.at(i) = rfx_group_labels_train.at(i); + } + + // Unpack RFX terms + Eigen::VectorXd alpha_init; + Eigen::MatrixXd xi_init; + Eigen::MatrixXd sigma_alpha_init; + Eigen::MatrixXd sigma_xi_init; + double sigma_xi_shape; + double sigma_xi_scale; + alpha_init.resize(rfx_alpha_init.size()); + xi_init.resize(rfx_xi_init.nrow(), rfx_xi_init.ncol()); + sigma_alpha_init.resize(rfx_sigma_alpha_init.nrow(), rfx_sigma_alpha_init.ncol()); + sigma_xi_init.resize(rfx_sigma_xi_init.nrow(), rfx_sigma_xi_init.ncol()); + for (int i = 0; i < rfx_alpha_init.size(); i++) { + alpha_init(i) = rfx_alpha_init.at(i); + } + for (int i = 0; i < rfx_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_xi_init.ncol(); j++) { + xi_init(i,j) = rfx_xi_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_alpha_init.nrow(); i++) { + for (int j = 0; j < 
rfx_sigma_alpha_init.ncol(); j++) { + sigma_alpha_init(i,j) = rfx_sigma_alpha_init(i,j); + } + } + for (int i = 0; i < rfx_sigma_xi_init.nrow(); i++) { + for (int j = 0; j < rfx_sigma_xi_init.ncol(); j++) { + sigma_xi_init(i,j) = rfx_sigma_xi_init(i,j); + } + } + sigma_xi_shape = rfx_sigma_xi_shape; + sigma_xi_scale = rfx_sigma_xi_scale; + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + double* train_rfx_basis_data_ptr = REAL(PROTECT(rfx_basis_train)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, 
num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Load rfx data + bart_dispatcher.AddRFXTerm(train_rfx_basis_data_ptr, rfx_group_labels_train_cpp, num_rows_train, + num_rfx_groups_train, num_rfx_basis_train, false, true, alpha_init, + xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } + + // Unprotect pointers to R data + UNPROTECT(3); + + // Release management of the pointer to R session + return cpp11::external_pointer(bart_result_ptr_.release()); +} + +[[cpp11::register]] +cpp11::external_pointer run_bart_cpp_nobasis_notest_norfx( + cpp11::doubles covariates_train, cpp11::doubles outcome_train, + int num_rows_train, int num_covariates_train, + cpp11::integers feature_types, cpp11::doubles variable_weights, + int num_trees, int output_dimension, bool is_leaf_constant, + double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, + int min_samples_leaf, int cutpoint_grid_size, 
cpp11::doubles_matrix<> leaf_cov_init, + double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, + int leaf_model_int, bool sample_global_var, bool sample_leaf_var +) { + // Create smart pointer to newly allocated object + std::unique_ptr bart_result_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast(feature_types[i]); + } + + // Convert leaf covariance to Eigen::MatrixXd + int leaf_dim = leaf_cov_init.nrow(); + Eigen::MatrixXd leaf_cov(leaf_cov_init.nrow(), leaf_cov_init.ncol()); + for (int i = 0; i < leaf_cov_init.nrow(); i++) { + leaf_cov(i,i) = leaf_cov_init(i,i); + for (int j = 0; j < i; j++) { + leaf_cov(i,j) = leaf_cov_init(i,j); + leaf_cov(j,i) = leaf_cov_init(j,i); + } + } + + // Check inputs + // if ((leaf_model_int == 1) || (leaf_model_int == 2)) { + // StochTree::Log::Fatal("Must provide basis for leaf regression"); + // } + + // Create BART dispatcher and add data + double* train_covariate_data_ptr = REAL(PROTECT(covariates_train)); + double* train_outcome_data_ptr = REAL(PROTECT(outcome_train)); + if (leaf_model_int == 0) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, 
global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else if (leaf_model_int == 1) { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } else { + // Create the dispatcher and load the data + StochTree::BARTDispatcher bart_dispatcher{}; + // Load training data + bart_dispatcher.AddDataset(train_covariate_data_ptr, num_rows_train, num_covariates_train, false, true); + bart_dispatcher.AddTrainOutcome(train_outcome_data_ptr, num_rows_train); + // Run the sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + sample_global_var, sample_leaf_var, random_seed + ); + } + // Unprotect pointers to R data UNPROTECT(2); diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 73eab2cd..39edd378 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -6,10 +6,59 @@ #include // R_bart.cpp -cpp11::external_pointer run_bart_cpp(cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int 
cutpoint_grid_size, double a_leaf, double b_leaf, double nu, double lamb, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int); -extern "C" SEXP _stochtree_run_bart_cpp(SEXP covariates, SEXP outcome, SEXP feature_types, SEXP variable_weights, SEXP num_rows, SEXP num_covariates, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int) { +cpp11::external_pointer run_bart_cpp_basis_test_rfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles covariates_test, cpp11::doubles basis_test, int num_rows_test, int num_covariates_test, int num_basis_test, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_basis_test_rfx(SEXP covariates_train, 
SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP covariates_test, SEXP basis_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP num_basis_test, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP rfx_basis_test, SEXP rfx_group_labels_test, SEXP num_rfx_basis_test, SEXP num_rfx_groups_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { BEGIN_CPP11 - return cpp11::as_sexp(run_bart_cpp(cpp11::as_cpp>(covariates), cpp11::as_cpp>(outcome), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_rows), cpp11::as_cpp>(num_covariates), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int))); + return cpp11::as_sexp(run_bart_cpp_basis_test_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(covariates_test), 
cpp11::as_cpp>(basis_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(num_basis_test), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(rfx_basis_test), cpp11::as_cpp>(rfx_group_labels_test), cpp11::as_cpp>(num_rfx_basis_test), cpp11::as_cpp>(num_rfx_groups_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer run_bart_cpp_basis_test_norfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles covariates_test, cpp11::doubles basis_test, int num_rows_test, int num_covariates_test, int num_basis_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int 
num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_basis_test_norfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP covariates_test, SEXP basis_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP num_basis_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_basis_test_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(basis_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(num_basis_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer run_bart_cpp_basis_notest_rfx(cpp11::doubles covariates_train, cpp11::doubles 
basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_basis_notest_rfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_basis_notest_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), 
cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer run_bart_cpp_basis_notest_norfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_basis_notest_norfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP 
alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_basis_notest_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer run_bart_cpp_nobasis_test_rfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles covariates_test, int num_rows_test, int num_covariates_test, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> 
leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_test_rfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP covariates_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP rfx_basis_test, SEXP rfx_group_labels_test, SEXP num_rfx_basis_test, SEXP num_rfx_groups_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_nobasis_test_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(rfx_basis_test), cpp11::as_cpp>(rfx_group_labels_test), cpp11::as_cpp>(num_rfx_basis_test), cpp11::as_cpp>(num_rfx_groups_test), cpp11::as_cpp>(feature_types), 
cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer run_bart_cpp_nobasis_test_norfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles covariates_test, int num_rows_test, int num_covariates_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_test_norfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP covariates_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP 
leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_nobasis_test_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer run_bart_cpp_nobasis_notest_rfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> 
rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_rfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_nobasis_notest_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), 
cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); + END_CPP11 +} +// R_bart.cpp +cpp11::external_pointer run_bart_cpp_nobasis_notest_norfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_norfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_cpp_nobasis_notest_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), 
cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); END_CPP11 } // R_data.cpp @@ -876,7 +925,14 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, - {"_stochtree_run_bart_cpp", (DL_FUNC) &_stochtree_run_bart_cpp, 24}, + {"_stochtree_run_bart_cpp_basis_notest_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_norfx, 28}, + {"_stochtree_run_bart_cpp_basis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_rfx, 38}, + {"_stochtree_run_bart_cpp_basis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_norfx, 33}, + {"_stochtree_run_bart_cpp_basis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_rfx, 47}, + {"_stochtree_run_bart_cpp_nobasis_notest_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_norfx, 26}, + {"_stochtree_run_bart_cpp_nobasis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_rfx, 36}, + {"_stochtree_run_bart_cpp_nobasis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_norfx, 29}, + {"_stochtree_run_bart_cpp_nobasis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_rfx, 43}, {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 13}, {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 13}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 4}, diff --git a/src/stochtree_types.h b/src/stochtree_types.h index 584f19b4..6569badb 100644 --- a/src/stochtree_types.h +++ b/src/stochtree_types.h @@ -3,6 +3,7 @@ #include #include #include +#include #include 
#include #include From 698ec7018bdc2aaf14b1cd0e6dc0006f107a1922 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 18 Jul 2024 02:08:14 -0400 Subject: [PATCH 09/18] Updated BART function --- R/bart.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/bart.R b/R/bart.R index 25fabd7d..8cfe6533 100644 --- a/R/bart.R +++ b/R/bart.R @@ -747,6 +747,7 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = } else { num_rows_test <- 0 } + num_samples <- num_gfr + num_burnin + num_mcmc # Update variable weights variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x)) From 44237dbe9787342571a1722bde570c64e5ac4177 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 19 Jul 2024 01:05:48 -0400 Subject: [PATCH 10/18] Added max_depth control to MCMC and GFR samplers --- R/bart.R | 26 +-- R/bcf.R | 19 +- R/cpp11.R | 36 ++-- R/model.R | 10 +- debug/api_debug.cpp | 3 +- include/stochtree/bart.h | 4 +- include/stochtree/prior.h | 6 +- include/stochtree/tree_sampler.h | 305 ++++++++++++++++--------------- man/ForestModel.Rd | 5 +- man/bart.Rd | 3 + man/bart_specialized.Rd | 3 + man/bcf.Rd | 6 + man/createForestModel.Rd | 5 +- src/R_bart.cpp | 64 +++---- src/cpp11.cpp | 72 ++++---- src/py_stochtree.cpp | 4 +- src/sampler.cpp | 4 +- stochtree/bart.py | 6 +- stochtree/bcf.py | 12 +- stochtree/sampler.py | 4 +- test/cpp/test_model.cpp | 12 +- 21 files changed, 334 insertions(+), 275 deletions(-) diff --git a/R/bart.R b/R/bart.R index 8cfe6533..cc2cd1ab 100644 --- a/R/bart.R +++ b/R/bart.R @@ -30,6 +30,7 @@ #' @param beta Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. #' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0. 
#' @param min_samples_leaf Minimum allowable size of a leaf, in terms of training samples. Default: 5. +#' @param max_depth Maximum depth of any tree in the ensemble. Default: 10. Can be overriden with ``-1`` which does not enforce any depth limits on trees. #' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3. #' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021). #' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3. @@ -79,7 +80,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, rfx_basis_train = NULL, X_test = NULL, W_test = NULL, group_ids_test = NULL, rfx_basis_test = NULL, cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95, - beta = 2.0, min_samples_leaf = 5, leaf_model = 0, + beta = 2.0, min_samples_leaf = 5, max_depth = 10, leaf_model = 0, nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, q = 0.9, sigma2_init = NULL, variable_weights = NULL, num_trees = 200, num_gfr = 5, num_burnin = 0, @@ -130,7 +131,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { rfx_basis_test <- as.matrix(rfx_basis_test) } - + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) 
has_rfx <- F has_rfx_test <- F @@ -273,7 +274,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, # Sampling data structures feature_types <- as.integer(feature_types) - forest_model <- createForestModel(forest_dataset_train, feature_types, num_trees, nrow(X_train), alpha, beta, min_samples_leaf) + forest_model <- createForestModel(forest_dataset_train, feature_types, num_trees, nrow(X_train), alpha, beta, min_samples_leaf, max_depth) # Container of forest samples forest_samples <- createForestContainer(num_trees, output_dimension, is_leaf_constant) @@ -654,6 +655,7 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL #' @param beta Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. #' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0. #' @param min_samples_leaf Minimum allowable size of a leaf, in terms of training samples. Default: 5. +#' @param max_depth Maximum depth of any tree in the ensemble. Default: 10. Can be overriden with ``-1`` which does not enforce any depth limits on trees. #' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3. #' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021). #' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3. 
@@ -705,7 +707,7 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = rfx_basis_train = NULL, X_test = NULL, W_test = NULL, group_ids_test = NULL, rfx_basis_test = NULL, cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95, - beta = 2.0, min_samples_leaf = 5, leaf_model = 0, + beta = 2.0, min_samples_leaf = 5, max_depth = 10, leaf_model = 0, nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, q = 0.9, sigma2_init = NULL, variable_weights = NULL, num_trees = 200, num_gfr = 5, num_burnin = 0, @@ -929,7 +931,7 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, as.numeric(rfx_basis_test), group_ids_test, num_basis_rfx, num_rfx_groups, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, - alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale @@ -940,7 +942,7 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = num_rows_train, num_cov_train, num_basis_train, as.numeric(X_test), as.numeric(W_test), num_rows_test, num_cov_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, - alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, sample_global_var, sample_leaf_var ) @@ -950,7 +952,7 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = num_rows_train, num_cov_train, num_basis_train, 
as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, - alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale @@ -960,7 +962,7 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = as.numeric(X_train), as.numeric(W_train), resid_train, num_rows_train, num_cov_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, - alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, sample_global_var, sample_leaf_var ) @@ -972,7 +974,7 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, as.numeric(rfx_basis_test), group_ids_test, num_basis_rfx, num_rfx_groups, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, - alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale @@ -983,7 +985,7 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = num_rows_train, num_cov_train, as.numeric(X_test), num_rows_test, num_cov_test, 
feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, - alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, sample_global_var, sample_leaf_var ) @@ -993,7 +995,7 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = num_rows_train, num_cov_train, as.numeric(rfx_basis_train), group_ids_train, num_basis_rfx, num_rfx_groups, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, - alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, sample_global_var, sample_leaf_var, alpha_init, xi_init, sigma_alpha_init, sigma_xi_init, sigma_xi_shape, sigma_xi_scale @@ -1003,7 +1005,7 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = as.numeric(X_train), resid_train, num_rows_train, num_cov_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, - alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, cutpoint_grid_size, + alpha, beta, a_leaf, b_leaf, nu, lambda, min_samples_leaf, max_depth, cutpoint_grid_size, tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model, sample_global_var, sample_leaf_var ) diff --git a/R/bcf.R b/R/bcf.R index 3425e96b..42b48b39 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -30,6 +30,8 @@ #' @param beta_tau Exponent that decreases split probabilities for nodes of depth > 0 for the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: 3.0. 
#' @param min_samples_leaf_mu Minimum allowable size of a leaf, in terms of training samples, for the prognostic forest. Default: 5. #' @param min_samples_leaf_tau Minimum allowable size of a leaf, in terms of training samples, for the treatment effect forest. Default: 5. +#' @param max_depth_mu Maximum depth of any tree in the mu ensemble. Default: 10. Can be overriden with ``-1`` which does not enforce any depth limits on trees. +#' @param max_depth_tau Maximum depth of any tree in the tau ensemble. Default: 5. Can be overriden with ``-1`` which does not enforce any depth limits on trees. #' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3. #' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021). #' @param a_leaf_mu Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the prognostic forest. Default: 3. 
@@ -116,12 +118,13 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU group_ids_test = NULL, rfx_basis_test = NULL, cutpoint_grid_size = 100, sigma_leaf_mu = NULL, sigma_leaf_tau = NULL, alpha_mu = 0.95, alpha_tau = 0.25, beta_mu = 2.0, beta_tau = 3.0, min_samples_leaf_mu = 5, min_samples_leaf_tau = 5, - nu = 3, lambda = NULL, a_leaf_mu = 3, a_leaf_tau = 3, b_leaf_mu = NULL, b_leaf_tau = NULL, - q = 0.9, sigma2 = NULL, variable_weights = NULL, keep_vars_mu = NULL, drop_vars_mu = NULL, - keep_vars_tau = NULL, drop_vars_tau = NULL, num_trees_mu = 250, num_trees_tau = 50, - num_gfr = 5, num_burnin = 0, num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T, - sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, - b_1 = 0.5, rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F) { + max_depth_mu = 10, max_depth_tau = 5, nu = 3, lambda = NULL, a_leaf_mu = 3, a_leaf_tau = 3, + b_leaf_mu = NULL, b_leaf_tau = NULL, q = 0.9, sigma2 = NULL, variable_weights = NULL, + keep_vars_mu = NULL, drop_vars_mu = NULL, keep_vars_tau = NULL, drop_vars_tau = NULL, + num_trees_mu = 250, num_trees_tau = 50, num_gfr = 5, num_burnin = 0, num_mcmc = 100, + sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F, + propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5, + rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F) { # Variable weight preprocessing (and initialization if necessary) if (is.null(variable_weights)) { variable_weights = rep(1/ncol(X_train), ncol(X_train)) @@ -493,8 +496,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU rng <- createRNG(random_seed) # Sampling data structures - forest_model_mu <- createForestModel(forest_dataset_train, feature_types, num_trees_mu, nrow(X_train), alpha_mu, beta_mu, min_samples_leaf_mu) - forest_model_tau <- 
createForestModel(forest_dataset_train, feature_types, num_trees_tau, nrow(X_train), alpha_tau, beta_tau, min_samples_leaf_tau) + forest_model_mu <- createForestModel(forest_dataset_train, feature_types, num_trees_mu, nrow(X_train), alpha_mu, beta_mu, min_samples_leaf_mu, max_depth_mu) + forest_model_tau <- createForestModel(forest_dataset_train, feature_types, num_trees_tau, nrow(X_train), alpha_tau, beta_tau, min_samples_leaf_tau, max_depth_tau) # Container of forest samples forest_samples_mu <- createForestContainer(num_trees_mu, 1, T) diff --git a/R/cpp11.R b/R/cpp11.R index 1a4e6ac7..ac94d658 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -1,35 +1,35 @@ # Generated by cpp11: do not edit by hand -run_bart_cpp_basis_test_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { - .Call(`_stochtree_run_bart_cpp_basis_test_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, 
nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +run_bart_cpp_basis_test_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_basis_test_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) } -run_bart_cpp_basis_test_norfx <- function(covariates_train, basis_train, outcome_train, 
num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { - .Call(`_stochtree_run_bart_cpp_basis_test_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) +run_bart_cpp_basis_test_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_basis_test_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, covariates_test, basis_test, num_rows_test, num_covariates_test, num_basis_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, 
sample_leaf_var) } -run_bart_cpp_basis_notest_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { - .Call(`_stochtree_run_bart_cpp_basis_notest_rfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +run_bart_cpp_basis_notest_rfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_basis_notest_rfx`, covariates_train, 
basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) } -run_bart_cpp_basis_notest_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { - .Call(`_stochtree_run_bart_cpp_basis_notest_norfx`, covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) +run_bart_cpp_basis_notest_norfx <- function(covariates_train, basis_train, outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_basis_notest_norfx`, covariates_train, basis_train, 
outcome_train, num_rows_train, num_covariates_train, num_basis_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) } -run_bart_cpp_nobasis_test_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { - .Call(`_stochtree_run_bart_cpp_nobasis_test_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +run_bart_cpp_nobasis_test_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, 
num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_nobasis_test_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, rfx_basis_test, rfx_group_labels_test, num_rfx_basis_test, num_rfx_groups_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) } -run_bart_cpp_nobasis_test_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { - .Call(`_stochtree_run_bart_cpp_nobasis_test_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, 
feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) +run_bart_cpp_nobasis_test_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_nobasis_test_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, covariates_test, num_rows_test, num_covariates_test, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) } -run_bart_cpp_nobasis_notest_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { - .Call(`_stochtree_run_bart_cpp_nobasis_notest_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, 
rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) +run_bart_cpp_nobasis_notest_rfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) { + .Call(`_stochtree_run_bart_cpp_nobasis_notest_rfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, rfx_basis_train, rfx_group_labels_train, num_rfx_basis_train, num_rfx_groups_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var, rfx_alpha_init, rfx_xi_init, rfx_sigma_alpha_init, rfx_sigma_xi_init, rfx_sigma_xi_shape, rfx_sigma_xi_scale) } -run_bart_cpp_nobasis_notest_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, 
global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { - .Call(`_stochtree_run_bart_cpp_nobasis_notest_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) +run_bart_cpp_nobasis_notest_norfx <- function(covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) { + .Call(`_stochtree_run_bart_cpp_nobasis_notest_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, sample_leaf_var) } create_forest_dataset_cpp <- function() { @@ -320,8 +320,8 @@ rng_cpp <- function(random_seed) { .Call(`_stochtree_rng_cpp`, random_seed) } -tree_prior_cpp <- function(alpha, beta, min_samples_leaf) { - .Call(`_stochtree_tree_prior_cpp`, alpha, beta, min_samples_leaf) +tree_prior_cpp <- function(alpha, beta, min_samples_leaf, max_depth) { + .Call(`_stochtree_tree_prior_cpp`, alpha, beta, min_samples_leaf, max_depth) } forest_tracker_cpp <- function(data, feature_types, num_trees, n) { diff --git a/R/model.R b/R/model.R index b39ae12d..ecebd206 100644 --- a/R/model.R +++ b/R/model.R @@ -50,11 +50,12 @@ 
ForestModel <- R6::R6Class( #' @param alpha Root node split probability in tree prior #' @param beta Depth prior penalty in tree prior #' @param min_samples_leaf Minimum number of samples in a tree leaf + #' @param max_depth Maximum depth of any tree in an ensemble #' @return A new `ForestModel` object. - initialize = function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf) { + initialize = function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) { stopifnot(!is.null(forest_dataset$data_ptr)) self$tracker_ptr <- forest_tracker_cpp(forest_dataset$data_ptr, feature_types, num_trees, n) - self$tree_prior_ptr <- tree_prior_cpp(alpha, beta, min_samples_leaf) + self$tree_prior_ptr <- tree_prior_cpp(alpha, beta, min_samples_leaf, max_depth) }, #' @description @@ -115,12 +116,13 @@ createRNG <- function(random_seed = -1){ #' @param alpha Root node split probability in tree prior #' @param beta Depth prior penalty in tree prior #' @param min_samples_leaf Minimum number of samples in a tree leaf +#' @param max_depth Maximum depth of any tree in an ensemble #' #' @return `ForestModel` object #' @export -createForestModel <- function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf) { +createForestModel <- function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) { return(invisible(( - ForestModel$new(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf) + ForestModel$new(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) ))) } diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index 717640ef..ed330ca3 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -434,6 +434,7 @@ void RunDebugDeconstructed(int dgp_num = 0, bool rfx_included = false, int num_g double alpha = 1; double beta = 0.1; int min_samples_leaf = 1; + int max_depth = 10; int cutpoint_grid_size = 100; double a_rfx = 
1.; double b_rfx = 1.; @@ -458,7 +459,7 @@ void RunDebugDeconstructed(int dgp_num = 0, bool rfx_included = false, int num_g // Initialize tracker and tree prior ForestTracker tracker = ForestTracker(dataset.GetCovariates(), feature_types, num_trees, n); - TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf); + TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf, max_depth); // Initialize variance models GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index b6364744..3d40afaa 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -142,7 +142,7 @@ class BARTDispatcher { BARTResult& output, std::vector& feature_types, std::vector& variable_weights, int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, Eigen::MatrixXd& leaf_cov_init, double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, int cutpoint_grid_size, - bool sample_global_var, bool sample_leaf_var, int random_seed = -1 + bool sample_global_var, bool sample_leaf_var, int random_seed = -1, int max_depth = -1 ) { // Unpack sampling details num_gfr_ = num_gfr; @@ -209,7 +209,7 @@ class BARTDispatcher { // Initialize tracker and tree prior ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_); - TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf); + TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf, max_depth); // Initialize global variance model GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); diff --git a/include/stochtree/prior.h b/include/stochtree/prior.h index af095d2f..5d8686f7 100644 --- a/include/stochtree/prior.h +++ b/include/stochtree/prior.h @@ -42,22 +42,26 @@ class RandomEffectsRegressionGaussianPrior : public RandomEffectsGaussianPrior { class TreePrior { public: - 
TreePrior(double alpha, double beta, int32_t min_samples_in_leaf) { + TreePrior(double alpha, double beta, int32_t min_samples_in_leaf, int32_t max_depth = -1) { alpha_ = alpha; beta_ = beta; min_samples_in_leaf_ = min_samples_in_leaf; + max_depth_ = max_depth; } ~TreePrior() {} double GetAlpha() {return alpha_;} double GetBeta() {return beta_;} int32_t GetMinSamplesLeaf() {return min_samples_in_leaf_;} + int32_t GetMaxDepth() {return max_depth_;} void SetAlpha(double alpha) {alpha_ = alpha;} void SetBeta(double beta) {beta_ = beta;} void SetMinSamplesLeaf(int32_t min_samples_in_leaf) {min_samples_in_leaf_ = min_samples_in_leaf;} + void SetMaxDepth(int32_t max_depth) {max_depth_ = max_depth;} private: double alpha_; double beta_; int32_t min_samples_in_leaf_; + int32_t max_depth_; }; class IGVariancePrior { diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index e9ef65d9..2685361d 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -322,78 +322,87 @@ class MCMCForestSampler { int leaf_chosen = leaves[leaf_dist(gen)]; int leaf_depth = tree->GetDepth(leaf_chosen); - // Select a split variable at random - int p = dataset.GetCovariates().cols(); - CHECK_EQ(variable_weights.size(), p); - // std::vector var_weights(p); - // std::fill(var_weights.begin(), var_weights.end(), 1.0/p); - std::discrete_distribution<> var_dist(variable_weights.begin(), variable_weights.end()); - int var_chosen = var_dist(gen); - - // Determine the range of possible cutpoints - // TODO: specialize this for binary / ordered categorical / unordered categorical variables - double var_min, var_max; - VarSplitRange(tracker, dataset, tree_num, leaf_chosen, var_chosen, var_min, var_max); - if (var_max <= var_min) { - return; - } - - // Split based on var_min to var_max in a given node - std::uniform_real_distribution split_point_dist(var_min, var_max); - double split_point_chosen = split_point_dist(gen); + // Maximum leaf depth + 
int32_t max_depth = tree_prior.GetMaxDepth(); - // Create a split object - TreeSplit split = TreeSplit(split_point_chosen); - - // Compute the marginal likelihood of split and no split, given the leaf prior - std::tuple split_eval = leaf_model.EvaluateProposedSplit(dataset, tracker, residual, split, tree_num, leaf_chosen, var_chosen, global_variance); - double split_log_marginal_likelihood = std::get<0>(split_eval); - double no_split_log_marginal_likelihood = std::get<1>(split_eval); - int32_t left_n = std::get<2>(split_eval); - int32_t right_n = std::get<3>(split_eval); - - // Determine probability of growing the split node and its two new left and right nodes - double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta()); - double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); - double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); - - // Determine whether a "grow" move is possible from the newly formed tree - // in order to compute the probability of choosing "prune" from the new tree - // (which is always possible by construction) - bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen); - bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf(); - bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf(); - double prob_prune_new; - if (non_constant && (min_samples_left_check || min_samples_right_check)) { - prob_prune_new = 0.5; + // Terminate early if cannot be split + bool accept; + if ((leaf_depth >= max_depth) && (max_depth != -1)) { + accept = false; } else { - prob_prune_new = 1.0; - } - // Determine the number of leaves in the current tree and leaf parents in the proposed tree - int num_leaf_parents = tree->NumLeafParents(); - double p_leaf = 1/static_cast(num_leaves); - double p_leaf_parent = 1/static_cast(num_leaf_parents+1); + // Select a split variable at random + int p = 
dataset.GetCovariates().cols(); + CHECK_EQ(variable_weights.size(), p); + // std::vector var_weights(p); + // std::fill(var_weights.begin(), var_weights.end(), 1.0/p); + std::discrete_distribution<> var_dist(variable_weights.begin(), variable_weights.end()); + int var_chosen = var_dist(gen); + + // Determine the range of possible cutpoints + // TODO: specialize this for binary / ordered categorical / unordered categorical variables + double var_min, var_max; + VarSplitRange(tracker, dataset, tree_num, leaf_chosen, var_chosen, var_min, var_max); + if (var_max <= var_min) { + return; + } + + // Split based on var_min to var_max in a given node + std::uniform_real_distribution split_point_dist(var_min, var_max); + double split_point_chosen = split_point_dist(gen); + + // Create a split object + TreeSplit split = TreeSplit(split_point_chosen); + + // Compute the marginal likelihood of split and no split, given the leaf prior + std::tuple split_eval = leaf_model.EvaluateProposedSplit(dataset, tracker, residual, split, tree_num, leaf_chosen, var_chosen, global_variance); + double split_log_marginal_likelihood = std::get<0>(split_eval); + double no_split_log_marginal_likelihood = std::get<1>(split_eval); + int32_t left_n = std::get<2>(split_eval); + int32_t right_n = std::get<3>(split_eval); + + // Determine probability of growing the split node and its two new left and right nodes + double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta()); + double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); + double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta()); + + // Determine whether a "grow" move is possible from the newly formed tree + // in order to compute the probability of choosing "prune" from the new tree + // (which is always possible by construction) + bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen); + bool 
min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf(); + bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf(); + double prob_prune_new; + if (non_constant && (min_samples_left_check || min_samples_right_check)) { + prob_prune_new = 0.5; + } else { + prob_prune_new = 1.0; + } - // Compute the final MH ratio - double log_mh_ratio = ( - std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) + - std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood - ); - // Threshold at 0 - if (log_mh_ratio > 0) { - log_mh_ratio = 0; - } + // Determine the number of leaves in the current tree and leaf parents in the proposed tree + int num_leaf_parents = tree->NumLeafParents(); + double p_leaf = 1/static_cast(num_leaves); + double p_leaf_parent = 1/static_cast(num_leaf_parents+1); + + // Compute the final MH ratio + double log_mh_ratio = ( + std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) + + std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood + ); + // Threshold at 0 + if (log_mh_ratio > 0) { + log_mh_ratio = 0; + } - // Draw a uniform random variable and accept/reject the proposal on this basis - bool accept; - std::uniform_real_distribution mh_accept(0.0, 1.0); - double log_acceptance_prob = std::log(mh_accept(gen)); - if (log_acceptance_prob <= log_mh_ratio) { - accept = true; - AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false); - } else { - accept = false; + // Draw a uniform random variable and accept/reject the proposal on this basis + std::uniform_real_distribution mh_accept(0.0, 1.0); + double log_acceptance_prob = std::log(mh_accept(gen)); + if (log_acceptance_prob <= log_mh_ratio) { + accept = true; + AddSplitToModel(tracker, dataset, 
tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false); + } else { + accept = false; + } } } @@ -575,89 +584,99 @@ class GFRForestSampler { std::unordered_map>& node_index_map, std::deque& split_queue, int node_id, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, std::vector& feature_types) { - std::vector log_cutpoint_evaluations; - std::vector cutpoint_features; - std::vector cutpoint_values; - std::vector cutpoint_feature_types; - StochTree::data_size_t valid_cutpoint_count; - CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); - EvaluateCutpoints(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, - cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, - cutpoint_grid_container); - // TODO: maybe add some checks here? 
- - // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood - double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end()); - std::vector cutpoint_evaluations(log_cutpoint_evaluations.size()); - for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){ - cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll); - } - - // Sample the split (including a "no split" option) - std::discrete_distribution split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end()); - data_size_t split_chosen = split_dist(gen); + // Leaf depth + int leaf_depth = tree->GetDepth(node_id); + + // Maximum leaf depth + int32_t max_depth = tree_prior.GetMaxDepth(); + + if ((max_depth == -1) || (leaf_depth < max_depth)) { - if (split_chosen == valid_cutpoint_count){ - // "No split" sampled, don't split or add any nodes to split queue - return; - } else { - // Split sampled - int feature_split = cutpoint_features[split_chosen]; - FeatureType feature_type = cutpoint_feature_types[split_chosen]; - double split_value = cutpoint_values[split_chosen]; - // Perform all of the relevant "split" operations in the model, tree and training dataset + // Cutpoint enumeration + std::vector log_cutpoint_evaluations; + std::vector cutpoint_features; + std::vector cutpoint_values; + std::vector cutpoint_feature_types; + StochTree::data_size_t valid_cutpoint_count; + CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); + EvaluateCutpoints(tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, + cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, + cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, + cutpoint_grid_container); + // TODO: maybe add some checks here? 
- // Compute node sample size - data_size_t node_n = node_end - node_begin; + // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood + double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end()); + std::vector cutpoint_evaluations(log_cutpoint_evaluations.size()); + for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){ + cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll); + } - // Actual numeric cutpoint used for ordered categorical and numeric features - double split_value_numeric; - TreeSplit tree_split; + // Sample the split (including a "no split" option) + std::discrete_distribution split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end()); + data_size_t split_chosen = split_dist(gen); - // We will use these later in the model expansion - data_size_t left_n = 0; - data_size_t right_n = 0; - data_size_t sort_idx; - double feature_value; - bool split_true; - - if (feature_type == FeatureType::kUnorderedCategorical) { - // Determine the number of categories available in a categorical split and the set of categories that route observations to the left node after split - int num_categories; - std::vector categories = cutpoint_grid_container.CutpointVector(static_cast(split_value), feature_split); - tree_split = TreeSplit(categories); - } else if (feature_type == FeatureType::kOrderedCategorical) { - // Convert the bin split to an actual split value - split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); - tree_split = TreeSplit(split_value_numeric); - } else if (feature_type == FeatureType::kNumeric) { - // Convert the bin split to an actual split value - split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); - tree_split = TreeSplit(split_value_numeric); + if (split_chosen == valid_cutpoint_count){ + // "No split" sampled, don't 
split or add any nodes to split queue + return; } else { - Log::Fatal("Invalid split type"); - } - - // Add split to tree and trackers - AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); - - // Determine the number of observation in the newly created left node - int left_node = tree->LeftChild(node_id); - int right_node = tree->RightChild(node_id); - auto left_begin_iter = tracker.SortedNodeBeginIterator(left_node, feature_split); - auto left_end_iter = tracker.SortedNodeEndIterator(left_node, feature_split); - for (auto i = left_begin_iter; i < left_end_iter; i++) { - left_n += 1; - } + // Split sampled + int feature_split = cutpoint_features[split_chosen]; + FeatureType feature_type = cutpoint_feature_types[split_chosen]; + double split_value = cutpoint_values[split_chosen]; + // Perform all of the relevant "split" operations in the model, tree and training dataset + + // Compute node sample size + data_size_t node_n = node_end - node_begin; + + // Actual numeric cutpoint used for ordered categorical and numeric features + double split_value_numeric; + TreeSplit tree_split; + + // We will use these later in the model expansion + data_size_t left_n = 0; + data_size_t right_n = 0; + data_size_t sort_idx; + double feature_value; + bool split_true; + + if (feature_type == FeatureType::kUnorderedCategorical) { + // Determine the number of categories available in a categorical split and the set of categories that route observations to the left node after split + int num_categories; + std::vector categories = cutpoint_grid_container.CutpointVector(static_cast(split_value), feature_split); + tree_split = TreeSplit(categories); + } else if (feature_type == FeatureType::kOrderedCategorical) { + // Convert the bin split to an actual split value + split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); + tree_split = TreeSplit(split_value_numeric); + } else if 
(feature_type == FeatureType::kNumeric) { + // Convert the bin split to an actual split value + split_value_numeric = cutpoint_grid_container.CutpointValue(static_cast(split_value), feature_split); + tree_split = TreeSplit(split_value_numeric); + } else { + Log::Fatal("Invalid split type"); + } + + // Add split to tree and trackers + AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); + + // Determine the number of observation in the newly created left node + int left_node = tree->LeftChild(node_id); + int right_node = tree->RightChild(node_id); + auto left_begin_iter = tracker.SortedNodeBeginIterator(left_node, feature_split); + auto left_end_iter = tracker.SortedNodeEndIterator(left_node, feature_split); + for (auto i = left_begin_iter; i < left_end_iter; i++) { + left_n += 1; + } - // Add the begin and end indices for the new left and right nodes to node_index_map - node_index_map.insert({left_node, std::make_pair(node_begin, node_begin + left_n)}); - node_index_map.insert({right_node, std::make_pair(node_begin + left_n, node_end)}); + // Add the begin and end indices for the new left and right nodes to node_index_map + node_index_map.insert({left_node, std::make_pair(node_begin, node_begin + left_n)}); + node_index_map.insert({right_node, std::make_pair(node_begin + left_n, node_end)}); - // Add the left and right nodes to the split tracker - split_queue.push_front(right_node); - split_queue.push_front(left_node); + // Add the left and right nodes to the split tracker + split_queue.push_front(right_node); + split_queue.push_front(left_node); + } } } diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd index d30159f6..3a8cbc6b 100644 --- a/man/ForestModel.Rd +++ b/man/ForestModel.Rd @@ -37,7 +37,8 @@ Create a new ForestModel object. n, alpha, beta, - min_samples_leaf + min_samples_leaf, + max_depth )}\if{html}{\out{}} } @@ -57,6 +58,8 @@ Create a new ForestModel object. 
\item{\code{beta}}{Depth prior penalty in tree prior} \item{\code{min_samples_leaf}}{Minimum number of samples in a tree leaf} + +\item{\code{max_depth}}{Maximum depth of any tree in an ensemble} } \if{html}{\out{}} } diff --git a/man/bart.Rd b/man/bart.Rd index 1a004d36..f274891b 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -19,6 +19,7 @@ bart( alpha = 0.95, beta = 2, min_samples_leaf = 5, + max_depth = 10, leaf_model = 0, nu = 3, lambda = NULL, @@ -83,6 +84,8 @@ that were not in the training set.} \item{min_samples_leaf}{Minimum allowable size of a leaf, in terms of training samples. Default: 5.} +\item{max_depth}{Maximum depth of any tree in the ensemble. Default: 10. Can be overriden with \code{-1} which does not enforce any depth limits on trees.} + \item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.} \item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.} diff --git a/man/bart_specialized.Rd b/man/bart_specialized.Rd index ed44bf1f..5ac17209 100644 --- a/man/bart_specialized.Rd +++ b/man/bart_specialized.Rd @@ -19,6 +19,7 @@ bart_specialized( alpha = 0.95, beta = 2, min_samples_leaf = 5, + max_depth = 10, leaf_model = 0, nu = 3, lambda = NULL, @@ -85,6 +86,8 @@ that were not in the training set.} \item{min_samples_leaf}{Minimum allowable size of a leaf, in terms of training samples. Default: 5.} +\item{max_depth}{Maximum depth of any tree in the ensemble. Default: 10. Can be overriden with \code{-1} which does not enforce any depth limits on trees.} + \item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.} \item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. 
Default: 3.} diff --git a/man/bcf.Rd b/man/bcf.Rd index eba72ab9..9dac1939 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -25,6 +25,8 @@ bcf( beta_tau = 3, min_samples_leaf_mu = 5, min_samples_leaf_tau = 5, + max_depth_mu = 10, + max_depth_tau = 5, nu = 3, lambda = NULL, a_leaf_mu = 3, @@ -108,6 +110,10 @@ that were not in the training set.} \item{min_samples_leaf_tau}{Minimum allowable size of a leaf, in terms of training samples, for the treatment effect forest. Default: 5.} +\item{max_depth_mu}{Maximum depth of any tree in the mu ensemble. Default: 10. Can be overriden with \code{-1} which does not enforce any depth limits on trees.} + +\item{max_depth_tau}{Maximum depth of any tree in the tau ensemble. Default: 5. Can be overriden with \code{-1} which does not enforce any depth limits on trees.} + \item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.} \item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. 
If not specified, this is calibrated as in Sparapani et al (2021).} diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd index 4c37e37d..7218fb05 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -11,7 +11,8 @@ createForestModel( n, alpha, beta, - min_samples_leaf + min_samples_leaf, + max_depth ) } \arguments{ @@ -28,6 +29,8 @@ createForestModel( \item{beta}{Depth prior penalty in tree prior} \item{min_samples_leaf}{Minimum number of samples in a tree leaf} + +\item{max_depth}{Maximum depth of any tree in an ensemble} } \value{ \code{ForestModel} object diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 6c7e1ab2..df65f7aa 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -25,7 +25,7 @@ cpp11::external_pointer run_bart_cpp_basis_test_rfx( int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, - int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, @@ -147,7 +147,7 @@ cpp11::external_pointer run_bart_cpp_basis_test_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else if (leaf_model_int == 1) { // Create the dispatcher and load the data @@ -169,7 +169,7 @@ cpp11::external_pointer 
run_bart_cpp_basis_test_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else { // Create the dispatcher and load the data @@ -191,7 +191,7 @@ cpp11::external_pointer run_bart_cpp_basis_test_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } @@ -211,7 +211,7 @@ cpp11::external_pointer run_bart_cpp_basis_test_norfx( cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, - int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var ) { @@ -271,7 +271,7 @@ cpp11::external_pointer run_bart_cpp_basis_test_norfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else if (leaf_model_int == 1) { // Create the dispatcher and load the data @@ -286,7 +286,7 @@ cpp11::external_pointer 
run_bart_cpp_basis_test_norfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else { // Create the dispatcher and load the data @@ -301,7 +301,7 @@ cpp11::external_pointer run_bart_cpp_basis_test_norfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } @@ -321,7 +321,7 @@ cpp11::external_pointer run_bart_cpp_basis_notest_rfx( cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, - int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, @@ -418,7 +418,7 @@ cpp11::external_pointer run_bart_cpp_basis_notest_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else if (leaf_model_int == 1) { // Create the dispatcher and load the data 
@@ -435,7 +435,7 @@ cpp11::external_pointer run_bart_cpp_basis_notest_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else { // Create the dispatcher and load the data @@ -452,7 +452,7 @@ cpp11::external_pointer run_bart_cpp_basis_notest_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } @@ -470,7 +470,7 @@ cpp11::external_pointer run_bart_cpp_basis_notest_norfx( cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, - int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var ) { @@ -520,7 +520,7 @@ cpp11::external_pointer run_bart_cpp_basis_notest_norfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else if (leaf_model_int == 1) { // Create the dispatcher and load the data @@ -533,7 +533,7 
@@ cpp11::external_pointer run_bart_cpp_basis_notest_norfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else { // Create the dispatcher and load the data @@ -546,7 +546,7 @@ cpp11::external_pointer run_bart_cpp_basis_notest_norfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } @@ -569,7 +569,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_test_rfx( int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, - int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, @@ -686,7 +686,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_test_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth 
); } else if (leaf_model_int == 1) { // Create the dispatcher and load the data @@ -708,7 +708,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_test_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else { // Create the dispatcher and load the data @@ -730,7 +730,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_test_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } @@ -750,7 +750,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_test_norfx( cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, - int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var ) { @@ -805,7 +805,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_test_norfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else if 
(leaf_model_int == 1) { // Create the dispatcher and load the data @@ -820,7 +820,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_test_norfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else { // Create the dispatcher and load the data @@ -835,7 +835,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_test_norfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } @@ -855,7 +855,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_notest_rfx( cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, - int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, @@ -951,7 +951,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_notest_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, 
sample_leaf_var, random_seed, max_depth ); } else if (leaf_model_int == 1) { // Create the dispatcher and load the data @@ -968,7 +968,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_notest_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else { // Create the dispatcher and load the data @@ -985,7 +985,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_notest_rfx( *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } @@ -1003,7 +1003,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_notest_norfx cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, - int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, + int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var ) { @@ -1052,7 +1052,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_notest_norfx *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, 
sample_leaf_var, random_seed, max_depth ); } else if (leaf_model_int == 1) { // Create the dispatcher and load the data @@ -1065,7 +1065,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_notest_norfx *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } else { // Create the dispatcher and load the data @@ -1078,7 +1078,7 @@ cpp11::external_pointer run_bart_cpp_nobasis_notest_norfx *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_cov, alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, - sample_global_var, sample_leaf_var, random_seed + sample_global_var, sample_leaf_var, random_seed, max_depth ); } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 39edd378..7fdcee69 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -6,59 +6,59 @@ #include // R_bart.cpp -cpp11::external_pointer run_bart_cpp_basis_test_rfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles covariates_test, cpp11::doubles basis_test, int num_rows_test, int num_covariates_test, int num_basis_test, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, 
cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); -extern "C" SEXP _stochtree_run_bart_cpp_basis_test_rfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP covariates_test, SEXP basis_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP num_basis_test, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP rfx_basis_test, SEXP rfx_group_labels_test, SEXP num_rfx_basis_test, SEXP num_rfx_groups_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { +cpp11::external_pointer run_bart_cpp_basis_test_rfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles covariates_test, cpp11::doubles basis_test, int num_rows_test, int num_covariates_test, int num_basis_test, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, int num_rfx_basis_test, int 
num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_basis_test_rfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP covariates_test, SEXP basis_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP num_basis_test, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP rfx_basis_test, SEXP rfx_group_labels_test, SEXP num_rfx_basis_test, SEXP num_rfx_groups_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { BEGIN_CPP11 - return cpp11::as_sexp(run_bart_cpp_basis_test_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), 
cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(basis_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(num_basis_test), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(rfx_basis_test), cpp11::as_cpp>(rfx_group_labels_test), cpp11::as_cpp>(num_rfx_basis_test), cpp11::as_cpp>(num_rfx_groups_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); + return cpp11::as_sexp(run_bart_cpp_basis_test_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(basis_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(num_basis_test), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(rfx_basis_test), cpp11::as_cpp>(rfx_group_labels_test), 
cpp11::as_cpp>(num_rfx_basis_test), cpp11::as_cpp>(num_rfx_groups_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); END_CPP11 } // R_bart.cpp -cpp11::external_pointer run_bart_cpp_basis_test_norfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles covariates_test, cpp11::doubles basis_test, int num_rows_test, int num_covariates_test, int num_basis_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); -extern "C" SEXP _stochtree_run_bart_cpp_basis_test_norfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP covariates_test, SEXP basis_test, SEXP 
num_rows_test, SEXP num_covariates_test, SEXP num_basis_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { +cpp11::external_pointer run_bart_cpp_basis_test_norfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles covariates_test, cpp11::doubles basis_test, int num_rows_test, int num_covariates_test, int num_basis_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_basis_test_norfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP covariates_test, SEXP basis_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP num_basis_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { BEGIN_CPP11 - return 
cpp11::as_sexp(run_bart_cpp_basis_test_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(basis_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(num_basis_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); + return cpp11::as_sexp(run_bart_cpp_basis_test_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(basis_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(num_basis_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), 
cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); END_CPP11 } // R_bart.cpp -cpp11::external_pointer run_bart_cpp_basis_notest_rfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); -extern "C" SEXP _stochtree_run_bart_cpp_basis_notest_rfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP 
rfx_sigma_xi_scale) { +cpp11::external_pointer run_bart_cpp_basis_notest_rfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_basis_notest_rfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { BEGIN_CPP11 - return cpp11::as_sexp(run_bart_cpp_basis_notest_rfx(cpp11::as_cpp>(covariates_train), 
cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); + return cpp11::as_sexp(run_bart_cpp_basis_notest_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), 
cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); END_CPP11 } // R_bart.cpp -cpp11::external_pointer run_bart_cpp_basis_notest_norfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); -extern "C" SEXP _stochtree_run_bart_cpp_basis_notest_norfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { +cpp11::external_pointer run_bart_cpp_basis_notest_norfx(cpp11::doubles covariates_train, cpp11::doubles basis_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, int num_basis_train, cpp11::integers feature_types, cpp11::doubles 
variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_basis_notest_norfx(SEXP covariates_train, SEXP basis_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP num_basis_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { BEGIN_CPP11 - return cpp11::as_sexp(run_bart_cpp_basis_notest_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); + return 
cpp11::as_sexp(run_bart_cpp_basis_notest_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(basis_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(num_basis_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); END_CPP11 } // R_bart.cpp -cpp11::external_pointer run_bart_cpp_nobasis_test_rfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles covariates_test, int num_rows_test, int num_covariates_test, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, 
cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); -extern "C" SEXP _stochtree_run_bart_cpp_nobasis_test_rfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP covariates_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP rfx_basis_test, SEXP rfx_group_labels_test, SEXP num_rfx_basis_test, SEXP num_rfx_groups_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { +cpp11::external_pointer run_bart_cpp_nobasis_test_rfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles covariates_test, int num_rows_test, int num_covariates_test, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::doubles rfx_basis_test, cpp11::integers rfx_group_labels_test, int num_rfx_basis_test, int num_rfx_groups_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool 
sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_test_rfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP covariates_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP rfx_basis_test, SEXP rfx_group_labels_test, SEXP num_rfx_basis_test, SEXP num_rfx_groups_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { BEGIN_CPP11 - return cpp11::as_sexp(run_bart_cpp_nobasis_test_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(rfx_basis_test), cpp11::as_cpp>(rfx_group_labels_test), cpp11::as_cpp>(num_rfx_basis_test), cpp11::as_cpp>(num_rfx_groups_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), 
cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); + return cpp11::as_sexp(run_bart_cpp_nobasis_test_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(rfx_basis_test), cpp11::as_cpp>(rfx_group_labels_test), cpp11::as_cpp>(num_rfx_basis_test), cpp11::as_cpp>(num_rfx_groups_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), 
cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); END_CPP11 } // R_bart.cpp -cpp11::external_pointer run_bart_cpp_nobasis_test_norfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles covariates_test, int num_rows_test, int num_covariates_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); -extern "C" SEXP _stochtree_run_bart_cpp_nobasis_test_norfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP covariates_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { +cpp11::external_pointer run_bart_cpp_nobasis_test_norfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles covariates_test, int num_rows_test, int num_covariates_test, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, 
cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_test_norfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP covariates_test, SEXP num_rows_test, SEXP num_covariates_test, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { BEGIN_CPP11 - return cpp11::as_sexp(run_bart_cpp_nobasis_test_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); + return cpp11::as_sexp(run_bart_cpp_nobasis_test_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), 
cpp11::as_cpp>(covariates_test), cpp11::as_cpp>(num_rows_test), cpp11::as_cpp>(num_covariates_test), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); END_CPP11 } // R_bart.cpp -cpp11::external_pointer run_bart_cpp_nobasis_notest_rfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); -extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_rfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, 
SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { +cpp11::external_pointer run_bart_cpp_nobasis_notest_rfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::doubles rfx_basis_train, cpp11::integers rfx_group_labels_train, int num_rfx_basis_train, int num_rfx_groups_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var, cpp11::doubles rfx_alpha_init, cpp11::doubles_matrix<> rfx_xi_init, cpp11::doubles_matrix<> rfx_sigma_alpha_init, cpp11::doubles_matrix<> rfx_sigma_xi_init, double rfx_sigma_xi_shape, double rfx_sigma_xi_scale); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_rfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP rfx_basis_train, SEXP rfx_group_labels_train, SEXP num_rfx_basis_train, SEXP num_rfx_groups_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, 
SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var, SEXP rfx_alpha_init, SEXP rfx_xi_init, SEXP rfx_sigma_alpha_init, SEXP rfx_sigma_xi_init, SEXP rfx_sigma_xi_shape, SEXP rfx_sigma_xi_scale) { BEGIN_CPP11 - return cpp11::as_sexp(run_bart_cpp_nobasis_notest_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); + return cpp11::as_sexp(run_bart_cpp_nobasis_notest_rfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(rfx_basis_train), cpp11::as_cpp>(rfx_group_labels_train), cpp11::as_cpp>(num_rfx_basis_train), cpp11::as_cpp>(num_rfx_groups_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), 
cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var), cpp11::as_cpp>(rfx_alpha_init), cpp11::as_cpp>>(rfx_xi_init), cpp11::as_cpp>>(rfx_sigma_alpha_init), cpp11::as_cpp>>(rfx_sigma_xi_init), cpp11::as_cpp>(rfx_sigma_xi_shape), cpp11::as_cpp>(rfx_sigma_xi_scale))); END_CPP11 } // R_bart.cpp -cpp11::external_pointer run_bart_cpp_nobasis_notest_norfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); -extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_norfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { +cpp11::external_pointer 
run_bart_cpp_nobasis_notest_norfx(cpp11::doubles covariates_train, cpp11::doubles outcome_train, int num_rows_train, int num_covariates_train, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, double a_leaf, double b_leaf, double nu, double lamb, int min_samples_leaf, int max_depth, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_cov_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int leaf_model_int, bool sample_global_var, bool sample_leaf_var); +extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_norfx(SEXP covariates_train, SEXP outcome_train, SEXP num_rows_train, SEXP num_covariates_train, SEXP feature_types, SEXP variable_weights, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP min_samples_leaf, SEXP max_depth, SEXP cutpoint_grid_size, SEXP leaf_cov_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP leaf_model_int, SEXP sample_global_var, SEXP sample_leaf_var) { BEGIN_CPP11 - return cpp11::as_sexp(run_bart_cpp_nobasis_notest_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), 
cpp11::as_cpp>(sample_leaf_var))); + return cpp11::as_sexp(run_bart_cpp_nobasis_notest_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); END_CPP11 } // R_data.cpp @@ -590,10 +590,10 @@ extern "C" SEXP _stochtree_rng_cpp(SEXP random_seed) { END_CPP11 } // sampler.cpp -cpp11::external_pointer tree_prior_cpp(double alpha, double beta, int min_samples_leaf); -extern "C" SEXP _stochtree_tree_prior_cpp(SEXP alpha, SEXP beta, SEXP min_samples_leaf) { +cpp11::external_pointer tree_prior_cpp(double alpha, double beta, int min_samples_leaf, int max_depth); +extern "C" SEXP _stochtree_tree_prior_cpp(SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP max_depth) { BEGIN_CPP11 - return cpp11::as_sexp(tree_prior_cpp(cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf))); + return cpp11::as_sexp(tree_prior_cpp(cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth))); END_CPP11 } // sampler.cpp @@ -925,21 +925,21 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, 
{"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, - {"_stochtree_run_bart_cpp_basis_notest_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_norfx, 28}, - {"_stochtree_run_bart_cpp_basis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_rfx, 38}, - {"_stochtree_run_bart_cpp_basis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_norfx, 33}, - {"_stochtree_run_bart_cpp_basis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_rfx, 47}, - {"_stochtree_run_bart_cpp_nobasis_notest_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_norfx, 26}, - {"_stochtree_run_bart_cpp_nobasis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_rfx, 36}, - {"_stochtree_run_bart_cpp_nobasis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_norfx, 29}, - {"_stochtree_run_bart_cpp_nobasis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_rfx, 43}, + {"_stochtree_run_bart_cpp_basis_notest_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_norfx, 29}, + {"_stochtree_run_bart_cpp_basis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_rfx, 39}, + {"_stochtree_run_bart_cpp_basis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_norfx, 34}, + {"_stochtree_run_bart_cpp_basis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_rfx, 48}, + {"_stochtree_run_bart_cpp_nobasis_notest_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_norfx, 27}, + {"_stochtree_run_bart_cpp_nobasis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_rfx, 37}, + {"_stochtree_run_bart_cpp_nobasis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_norfx, 30}, + {"_stochtree_run_bart_cpp_nobasis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_rfx, 44}, {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 13}, {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 13}, 
{"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 4}, {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 5}, {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, - {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 3}, + {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, {"_stochtree_update_residual_forest_container_cpp", (DL_FUNC) &_stochtree_update_residual_forest_container_cpp, 7}, {NULL, NULL, 0} }; diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index cb329f13..daf04dac 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -255,7 +255,7 @@ class ForestContainerCpp { class ForestSamplerCpp { public: - ForestSamplerCpp(ForestDatasetCpp& dataset, py::array_t feature_types, int num_trees, data_size_t num_obs, double alpha, double beta, int min_samples_leaf) { + ForestSamplerCpp(ForestDatasetCpp& dataset, py::array_t feature_types, int num_trees, data_size_t num_obs, double alpha, double beta, int min_samples_leaf, int max_depth = -1) { // Convert vector of integers to std::vector of enum FeatureType std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { @@ -265,7 +265,7 @@ class ForestSamplerCpp { // Initialize pointer to C++ ForestTracker and TreePrior classes StochTree::ForestDataset* dataset_ptr = dataset.GetDataset(); tracker_ = std::make_unique(dataset_ptr->GetCovariates(), feature_types_, num_trees, num_obs); - split_prior_ = std::make_unique(alpha, beta, min_samples_leaf); + split_prior_ = std::make_unique(alpha, beta, min_samples_leaf, max_depth); } ~ForestSamplerCpp() {} diff --git a/src/sampler.cpp b/src/sampler.cpp index 3347e913..0edf6a7a 100644 --- a/src/sampler.cpp +++ 
b/src/sampler.cpp @@ -173,9 +173,9 @@ cpp11::external_pointer rng_cpp(int random_seed = -1) { } [[cpp11::register]] -cpp11::external_pointer tree_prior_cpp(double alpha, double beta, int min_samples_leaf) { +cpp11::external_pointer tree_prior_cpp(double alpha, double beta, int min_samples_leaf, int max_depth = -1) { // Create smart pointer to newly allocated object - std::unique_ptr prior_ptr_ = std::make_unique(alpha, beta, min_samples_leaf); + std::unique_ptr prior_ptr_ = std::make_unique(alpha, beta, min_samples_leaf, max_depth); // Release management of the pointer to R session return cpp11::external_pointer(prior_ptr_.release()); diff --git a/stochtree/bart.py b/stochtree/bart.py index 963c47ee..9169de05 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -24,7 +24,7 @@ def is_sampled(self) -> bool: return self.sampled def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = None, X_test: np.array = None, basis_test: np.array = None, - cutpoint_grid_size = 100, sigma_leaf: float = None, alpha: float = 0.95, beta: float = 2.0, min_samples_leaf: int = 5, + cutpoint_grid_size = 100, sigma_leaf: float = None, alpha: float = 0.95, beta: float = 2.0, min_samples_leaf: int = 5, max_depth: int = 10, nu: float = 3, lamb: float = None, a_leaf: float = 3, b_leaf: float = None, q: float = 0.9, sigma2: float = None, num_trees: int = 200, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, sample_sigma_global: bool = True, sample_sigma_leaf: bool = True, random_seed: int = -1, keep_burnin: bool = False, keep_gfr: bool = False) -> None: @@ -56,6 +56,8 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. min_samples_leaf : :obj:`int`, optional Minimum allowable size of a leaf, in terms of training samples. Defaults to ``5``. + max_depth : :obj:`int`, optional + Maximum depth of any tree in the ensemble. Defaults to ``10``. 
Can be overriden with ``-1`` which does not enforce any depth limits on trees. nu : :obj:`float`, optional Shape parameter in the ``IG(nu, nu*lamb)`` global error variance model. Defaults to ``3``. lamb : :obj:`float`, optional @@ -217,7 +219,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N cpp_rng = RNG(random_seed) # Sampling data structures - forest_sampler = ForestSampler(forest_dataset_train, feature_types, num_trees, self.n_train, alpha, beta, min_samples_leaf) + forest_sampler = ForestSampler(forest_dataset_train, feature_types, num_trees, self.n_train, alpha, beta, min_samples_leaf, max_depth) # Determine the leaf model if not self.has_basis: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index bfc334d0..e1e54535 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -32,8 +32,8 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr X_test: Union[pd.DataFrame, np.array] = None, Z_test: np.array = None, pi_test: np.array = None, cutpoint_grid_size = 100, sigma_leaf_mu: float = None, sigma_leaf_tau: float = None, alpha_mu: float = 0.95, alpha_tau: float = 0.25, beta_mu: float = 2.0, beta_tau: float = 3.0, - min_samples_leaf_mu: int = 5, min_samples_leaf_tau: int = 5, nu: float = 3, lamb: float = None, - a_leaf_mu: float = 3, a_leaf_tau: float = 3, b_leaf_mu: float = None, b_leaf_tau: float = None, + min_samples_leaf_mu: int = 5, min_samples_leaf_tau: int = 5, max_depth_mu: int = 10, max_depth_tau: int = 5, + nu: float = 3, lamb: float = None, a_leaf_mu: float = 3, a_leaf_tau: float = 3, b_leaf_mu: float = None, b_leaf_tau: float = None, q: float = 0.9, sigma2: float = None, variable_weights: np.array = None, keep_vars_mu: Union[list, np.array] = None, drop_vars_mu: Union[list, np.array] = None, keep_vars_tau: Union[list, np.array] = None, drop_vars_tau: Union[list, np.array] = None, @@ -88,6 +88,10 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr Minimum 
allowable size of a leaf, in terms of training samples, for the prognostic forest. Defaults to ``5``. min_samples_leaf_tau : :obj:`int`, optional Minimum allowable size of a leaf, in terms of training samples, for the treatment effect forest. Defaults to ``5``. + max_depth_mu : :obj:`int`, optional + Maximum depth of any tree in the mu ensemble. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. + max_depth_tau : :obj:`int`, optional + Maximum depth of any tree in the tau ensemble. Defaults to ``5``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. nu : :obj:`float`, optional Shape parameter in the ``IG(nu, nu*lamb)`` global error variance model. Defaults to ``3``. lamb : :obj:`float`, optional @@ -614,8 +618,8 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr cpp_rng = RNG(random_seed) # Sampling data structures - forest_sampler_mu = ForestSampler(forest_dataset_train, feature_types, num_trees_mu, self.n_train, alpha_mu, beta_mu, min_samples_leaf_mu) - forest_sampler_tau = ForestSampler(forest_dataset_train, feature_types, num_trees_tau, self.n_train, alpha_tau, beta_tau, min_samples_leaf_tau) + forest_sampler_mu = ForestSampler(forest_dataset_train, feature_types, num_trees_mu, self.n_train, alpha_mu, beta_mu, min_samples_leaf_mu, max_depth_mu) + forest_sampler_tau = ForestSampler(forest_dataset_train, feature_types, num_trees_tau, self.n_train, alpha_tau, beta_tau, min_samples_leaf_tau, max_depth_tau) # Container of forest samples self.forest_container_mu = ForestContainer(num_trees_mu, 1, True) diff --git a/stochtree/sampler.py b/stochtree/sampler.py index 8b1a1abe..0a090a09 100644 --- a/stochtree/sampler.py +++ b/stochtree/sampler.py @@ -13,9 +13,9 @@ def __init__(self, random_seed: int) -> None: class ForestSampler: - def __init__(self, dataset: Dataset, feature_types: np.array, num_trees: int, num_obs: int, alpha: float, beta: float, 
min_samples_leaf: int) -> None: + def __init__(self, dataset: Dataset, feature_types: np.array, num_trees: int, num_obs: int, alpha: float, beta: float, min_samples_leaf: int, max_depth: int) -> None: # Initialize a ForestDatasetCpp object - self.forest_sampler_cpp = ForestSamplerCpp(dataset.dataset_cpp, feature_types, num_trees, num_obs, alpha, beta, min_samples_leaf) + self.forest_sampler_cpp = ForestSamplerCpp(dataset.dataset_cpp, feature_types, num_trees, num_obs, alpha, beta, min_samples_leaf, max_depth) def sample_one_iteration(self, forest_container: ForestContainer, dataset: Dataset, residual: Residual, rng: RNG, feature_types: np.array, cutpoint_grid_size: int, leaf_model_scale_input: np.array, diff --git a/test/cpp/test_model.cpp b/test/cpp/test_model.cpp index 52742cff..23e2e929 100644 --- a/test/cpp/test_model.cpp +++ b/test/cpp/test_model.cpp @@ -31,10 +31,11 @@ TEST(LeafConstantModel, FullEnumeration) { double alpha = 0.95; double beta = 1.25; int min_samples_leaf = 1; + int max_depth = -1; double global_variance = 1.; double tau = 1.; int cutpoint_grid_size = n; - StochTree::TreePrior tree_prior = StochTree::TreePrior(alpha, beta, min_samples_leaf); + StochTree::TreePrior tree_prior = StochTree::TreePrior(alpha, beta, min_samples_leaf, max_depth); // Construct temporary data structures needed to enumerate splits std::vector log_cutpoint_evaluations; @@ -88,10 +89,11 @@ TEST(LeafConstantModel, CutpointThinning) { double alpha = 0.95; double beta = 1.25; int min_samples_leaf = 1; + int max_depth = -1; double global_variance = 1.; double tau = 1.; int cutpoint_grid_size = 5; - StochTree::TreePrior tree_prior = StochTree::TreePrior(alpha, beta, min_samples_leaf); + StochTree::TreePrior tree_prior = StochTree::TreePrior(alpha, beta, min_samples_leaf, max_depth); // Construct temporary data structures needed to enumerate splits std::vector log_cutpoint_evaluations; @@ -145,10 +147,11 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) { double alpha = 
0.95; double beta = 1.25; int min_samples_leaf = 1; + int max_depth = -1; double global_variance = 1.; double tau = 1.; int cutpoint_grid_size = n; - StochTree::TreePrior tree_prior = StochTree::TreePrior(alpha, beta, min_samples_leaf); + StochTree::TreePrior tree_prior = StochTree::TreePrior(alpha, beta, min_samples_leaf, max_depth); // Construct temporary data structures needed to enumerate splits std::vector log_cutpoint_evaluations; @@ -203,10 +206,11 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) { double alpha = 0.95; double beta = 1.25; int min_samples_leaf = 1; + int max_depth = -1; double global_variance = 1.; double tau = 1.; int cutpoint_grid_size = 5; - StochTree::TreePrior tree_prior = StochTree::TreePrior(alpha, beta, min_samples_leaf); + StochTree::TreePrior tree_prior = StochTree::TreePrior(alpha, beta, min_samples_leaf, max_depth); // Construct temporary data structures needed to enumerate splits std::vector log_cutpoint_evaluations; From 64890f59e3bc8d3f726fc34cd8195fc55d1623db Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 19 Jul 2024 01:35:19 -0400 Subject: [PATCH 11/18] Added back the "streamlined" C++ loop for comparison --- NAMESPACE | 3 +- R/bart.R | 217 +++++++++++++++++- R/cpp11.R | 4 + include/stochtree/bart.h | 155 +++++++++++++ ...alized.Rd => bart_cpp_loop_generalized.Rd} | 6 +- man/bart_cpp_loop_specialized.Rd | 121 ++++++++++ src/R_bart.cpp | 47 ++++ src/cpp11.cpp | 8 + 8 files changed, 545 insertions(+), 16 deletions(-) rename man/{bart_specialized.Rd => bart_cpp_loop_generalized.Rd} (98%) create mode 100644 man/bart_cpp_loop_specialized.Rd diff --git a/NAMESPACE b/NAMESPACE index 4b3af259..03201872 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -5,7 +5,8 @@ S3method(getRandomEffectSamples,bcf) S3method(predict,bartmodel) S3method(predict,bcf) export(bart) -export(bart_specialized) +export(bart_cpp_loop_generalized) +export(bart_cpp_loop_specialized) export(bcf) export(computeForestKernels) 
export(computeForestLeafIndices) diff --git a/R/bart.R b/R/bart.R index cc2cd1ab..d129e1cf 100644 --- a/R/bart.R +++ b/R/bart.R @@ -703,17 +703,18 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL #' bart_model <- bart_specialized(X_train = X_train, y_train = y_train, X_test = X_test) #' # plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") #' # abline(0,1,col="red",lty=3,lwd=3) -bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, - rfx_basis_train = NULL, X_test = NULL, W_test = NULL, - group_ids_test = NULL, rfx_basis_test = NULL, - cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95, - beta = 2.0, min_samples_leaf = 5, max_depth = 10, leaf_model = 0, - nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, - q = 0.9, sigma2_init = NULL, variable_weights = NULL, - num_trees = 200, num_gfr = 5, num_burnin = 0, - num_mcmc = 100, sample_sigma = T, sample_tau = T, - random_seed = -1, keep_burnin = F, keep_gfr = F, - verbose = F, sample_global_var = T, sample_leaf_var = F){ +bart_cpp_loop_generalized <- function( + X_train, y_train, W_train = NULL, group_ids_train = NULL, + rfx_basis_train = NULL, X_test = NULL, W_test = NULL, + group_ids_test = NULL, rfx_basis_test = NULL, + cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95, + beta = 2.0, min_samples_leaf = 5, max_depth = 10, leaf_model = 0, + nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, + q = 0.9, sigma2_init = NULL, variable_weights = NULL, + num_trees = 200, num_gfr = 5, num_burnin = 0, + num_mcmc = 100, sample_sigma = T, sample_tau = T, + random_seed = -1, keep_burnin = F, keep_gfr = F, + verbose = F, sample_global_var = T, sample_leaf_var = F){ # Variable weight preprocessing (and initialization if necessary) if (is.null(variable_weights)) { variable_weights = rep(1/ncol(X_train), ncol(X_train)) @@ -1047,7 +1048,199 @@ bart_specialized <- function(X_train, y_train, W_train = NULL, group_ids_train = ) # 
if (has_test) result[["y_hat_test"]] = y_hat_test # if (sample_sigma) result[["sigma2_samples"]] = sigma2_samples - class(result) <- "simplifiedbart" + class(result) <- "bartcppgeneralized" + + return(result) +} + +#' Run the BART algorithm for supervised learning. +#' +#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. +#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be +#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, +#' categorical columns stored as ordered factors will be passed as integers to the core algorithm, along with the metadata +#' that the column is ordered categorical). +#' @param y_train Outcome to be modeled by the ensemble. +#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. +#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with +#' that of `X_train`. +#' @param cutpoint_grid_size Maximum size of the "grid" of potential cutpoints to consider. Default: 100. +#' @param tau_init Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. +#' @param alpha Prior probability of splitting for a tree of depth 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. +#' @param beta Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. +#' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0. +#' @param min_samples_leaf Minimum allowable size of a leaf, in terms of training samples. Default: 5. +#' @param max_depth Maximum depth of any tree in the ensemble. Default: -1. 
A value of ``-1`` does not enforce any depth limits on trees. +#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3. +#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021). +#' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3. +#' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. +#' @param q Quantile used to calibrate `lambda` as in Sparapani et al (2021). Default: 0.9. +#' @param sigma2_init Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here. +#' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. +#' @param num_trees Number of trees in the ensemble. Default: 200. +#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. +#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0. +#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. +#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`. +#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0. +#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0. +#' @param verbose Whether or not to print progress during the sampling loops. Default: FALSE. 
+#' +#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). +#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' f_XW <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +#' ) +#' noise_sd <- 1 +#' y <- f_XW + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' bart_model <- bart_cpp_loop_specialized(X_train = X_train, y_train = y_train, X_test = X_test) +#' # plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") +#' # abline(0,1,col="red",lty=3,lwd=3) +bart_cpp_loop_specialized <- function( + X_train, y_train, X_test = NULL, cutpoint_grid_size = 100, + tau_init = NULL, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, + max_depth = -1, nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, + q = 0.9, sigma2_init = NULL, variable_weights = NULL, + num_trees = 200, num_gfr = 5, num_burnin = 0, num_mcmc = 100, + random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F +){ + # Variable weight preprocessing (and initialization if necessary) + if (is.null(variable_weights)) { + variable_weights = rep(1/ncol(X_train), ncol(X_train)) + } + if (any(variable_weights < 0)) { + stop("variable_weights cannot have any negative weights") + } + + # Preprocess covariates + if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { + stop("X_train must be a matrix or dataframe") + } + if (!is.null(X_test)){ + if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { + stop("X_test must be a matrix or 
dataframe") + } + } + if (ncol(X_train) != length(variable_weights)) { + stop("length(variable_weights) must equal ncol(X_train)") + } + train_cov_preprocess_list <- preprocessTrainData(X_train) + X_train_metadata <- train_cov_preprocess_list$metadata + X_train <- train_cov_preprocess_list$data + original_var_indices <- X_train_metadata$original_var_indices + feature_types <- X_train_metadata$feature_types + feature_types <- as.integer(feature_types) + if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata) + + # Update variable weights + variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x)) + variable_weights <- variable_weights[original_var_indices]*variable_weights_adj + + # Data consistency checks + if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { + stop("X_train and X_test must have the same number of columns") + } + if (nrow(X_train) != length(y_train)) { + stop("X_train and y_train must have the same number of observations") + } + + # Convert y_train to numeric vector if not already converted + if (!is.null(dim(y_train))) { + y_train <- as.matrix(y_train) + } + + # Determine whether a basis vector is provided + has_basis = F + + # Determine whether a test set is provided + has_test = !is.null(X_test) + + # Standardize outcome separately for test and train + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) + resid_train <- (y_train-y_bar_train)/y_std_train + + # Calibrate priors for sigma^2 and tau + reg_basis <- X_train + sigma2hat <- (sigma(lm(resid_train~reg_basis)))^2 + quantile_cutoff <- 0.9 + if (is.null(lambda)) { + lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu + } + if (is.null(sigma2_init)) sigma2_init <- sigma2hat + if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees) + if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees) + current_leaf_scale <- as.matrix(tau_init) + current_sigma2 <- sigma2_init + + # Determine leaf model type + 
leaf_model <- 0 + + # Unpack model type info + output_dimension = 1 + is_leaf_constant = T + leaf_regression = F + + # Container of variance parameter samples + num_samples <- num_gfr + num_burnin + num_mcmc + + # Run the BART sampler + bart_result_ptr <- run_bart_specialized_cpp( + as.numeric(X_train), y_train, feature_types, variable_weights, nrow(X_train), + ncol(X_train), num_trees, output_dimension, is_leaf_constant, alpha, beta, + min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lambda, + tau_init, sigma2_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth + ) + + # Return results as a list + model_params <- list( + "sigma2_init" = sigma2_init, + "nu" = nu, + "lambda" = lambda, + "tau_init" = tau_init, + "a" = a_leaf, + "b" = b_leaf, + "outcome_mean" = y_bar_train, + "outcome_scale" = y_std_train, + "output_dimension" = output_dimension, + "is_leaf_constant" = is_leaf_constant, + "leaf_regression" = leaf_regression, + "requires_basis" = F, + "num_covariates" = ncol(X_train), + "num_basis" = 0, + "num_samples" = num_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, + "has_basis" = F, + "has_rfx" = F, + "has_rfx_basis" = F, + "num_rfx_basis" = 0, + "sample_sigma" = T, + "sample_tau" = F + ) + result <- list( + "model_params" = model_params + ) + class(result) <- "bartcppsimplified" return(result) } diff --git a/R/cpp11.R b/R/cpp11.R index ac94d658..9e546e78 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -32,6 +32,10 @@ run_bart_cpp_nobasis_notest_norfx <- function(covariates_train, outcome_train, n .Call(`_stochtree_run_bart_cpp_nobasis_notest_norfx`, covariates_train, outcome_train, num_rows_train, num_covariates_train, feature_types, variable_weights, num_trees, output_dimension, is_leaf_constant, alpha, beta, a_leaf, b_leaf, nu, lamb, min_samples_leaf, max_depth, cutpoint_grid_size, leaf_cov_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, leaf_model_int, sample_global_var, 
sample_leaf_var) } +run_bart_specialized_cpp <- function(covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth) { + .Call(`_stochtree_run_bart_specialized_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth) +} + create_forest_dataset_cpp <- function() { .Call(`_stochtree_create_forest_dataset_cpp`) } diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index 3d40afaa..1441cafb 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -333,6 +333,161 @@ class BARTDispatcher { MultivariateRegressionRandomEffectsModel rfx_model_; }; +class BARTResultSimplified { + public: + BARTResultSimplified(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) : + forests_samples_{num_trees, output_dimension, is_leaf_constant} {} + ~BARTResultSimplified() {} + ForestContainer& GetForests() {return forests_samples_;} + std::vector& GetTrainPreds() {return raw_preds_train_;} + std::vector& GetTestPreds() {return raw_preds_test_;} + std::vector& GetVarianceSamples() {return sigma_samples_;} + int NumGFRSamples() {return num_gfr_;} + int NumBurninSamples() {return num_burnin_;} + int NumMCMCSamples() {return num_mcmc_;} + int NumTrainObservations() {return num_train_;} + int NumTestObservations() {return num_test_;} + bool HasTestSet() {return has_test_set_;} + private: + ForestContainer forests_samples_; + std::vector raw_preds_train_; + std::vector raw_preds_test_; + std::vector sigma_samples_; + int num_gfr_{0}; + int num_burnin_{0}; + int 
num_mcmc_{0}; + int num_train_{0}; + int num_test_{0}; + bool has_test_set_{false}; +}; + +class BARTDispatcherSimplified { + public: + BARTDispatcherSimplified() {} + ~BARTDispatcherSimplified() {} + BARTResultSimplified CreateOutputObject(int num_trees, int output_dimension = 1, bool is_leaf_constant = true) { + return BARTResultSimplified(num_trees, output_dimension, is_leaf_constant); + } + void RunSampler( + BARTResultSimplified& output, std::vector& feature_types, std::vector& variable_weights, + int num_trees, int num_gfr, int num_burnin, int num_mcmc, double global_var_init, double leaf_var_init, + double alpha, double beta, double nu, double lamb, double a_leaf, double b_leaf, int min_samples_leaf, + int cutpoint_grid_size, int random_seed = -1, int max_depth = -1 + ) { + // Unpack sampling details + num_gfr_ = num_gfr; + num_burnin_ = num_burnin; + num_mcmc_ = num_mcmc; + int num_samples = num_gfr + num_burnin + num_mcmc; + + // Random number generation (assign, don't redeclare: a branch-local `rng` would shadow the engine used below) + std::mt19937 rng; + if (random_seed == -1) { + std::random_device rd; + rng = std::mt19937(rd()); + } + else { + rng = std::mt19937(random_seed); + } + + // Obtain references to forest / parameter samples and predictions in BARTResult + ForestContainer& forest_samples = output.GetForests(); + std::vector& sigma2_samples = output.GetVarianceSamples(); + std::vector& train_preds = output.GetTrainPreds(); + std::vector& test_preds = output.GetTestPreds(); + + // Clear and prepare vectors to store results + sigma2_samples.clear(); + train_preds.clear(); + test_preds.clear(); + sigma2_samples.resize(num_samples); + train_preds.resize(num_samples*num_train_); + if (has_test_set_) test_preds.resize(num_samples*num_test_); + + // Initialize tracker and tree prior (forward max_depth so the depth limit is actually enforced) + ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_); + TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf, max_depth); + + // Initialize variance model + GlobalHomoskedasticVarianceModel 
global_var_model = GlobalHomoskedasticVarianceModel(); + + // Initialize leaf model and samplers + GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_var_init); + GFRForestSampler gfr_sampler = GFRForestSampler(cutpoint_grid_size); + MCMCForestSampler mcmc_sampler = MCMCForestSampler(); + + // Running variable for current sampled value of global outcome variance parameter + double global_var = global_var_init; + + // Run the XBART Gibbs sampler + int iter = 0; + if (num_gfr > 0) { + for (int i = 0; i < num_gfr; i++) { + // Sample the forests + gfr_sampler.SampleOneIter(tracker, forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, + rng, variable_weights, global_var, feature_types, false); + + // Sample the global outcome + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + + // Increment sample counter + iter++; + } + } + + // Run the MCMC sampler + if (num_burnin + num_mcmc > 0) { + for (int i = 0; i < num_burnin + num_mcmc; i++) { + // Sample the forests + mcmc_sampler.SampleOneIter(tracker, forest_samples, leaf_model, train_dataset_, train_outcome_, tree_prior, + rng, variable_weights, global_var, true); + + // Sample the global outcome + global_var = global_var_model.SampleVarianceParameter(train_outcome_.GetData(), nu, lamb, rng); + sigma2_samples.at(iter) = global_var; + + // Increment sample counter + iter++; + } + } + + // Predict forests + forest_samples.PredictInPlace(train_dataset_, train_preds); + if (has_test_set_) forest_samples.PredictInPlace(test_dataset_, test_preds); + } + void AddDataset(double* covariates, data_size_t num_row, int num_col, bool is_row_major, bool train) { + if (train) { + train_dataset_ = ForestDataset(); + train_dataset_.AddCovariates(covariates, num_row, num_col, is_row_major); + num_train_ = num_row; + } else { + test_dataset_ = ForestDataset(); + test_dataset_.AddCovariates(covariates, num_row, 
num_col, is_row_major); + has_test_set_ = true; + num_test_ = num_row; + } + } + void AddTrainOutcome(double* outcome, data_size_t num_row) { + train_outcome_ = ColumnVector(); + train_outcome_.LoadData(outcome, num_row); + } + private: + // Sampling details + int num_gfr_{0}; + int num_burnin_{0}; + int num_mcmc_{0}; + int num_train_{0}; + int num_test_{0}; + bool has_test_set_{false}; + + // Sampling data objects + ForestDataset train_dataset_; + ForestDataset test_dataset_; + ColumnVector train_outcome_; +}; + + } // namespace StochTree #endif // STOCHTREE_SAMPLING_DISPATCH_H_ diff --git a/man/bart_specialized.Rd b/man/bart_cpp_loop_generalized.Rd similarity index 98% rename from man/bart_specialized.Rd rename to man/bart_cpp_loop_generalized.Rd index 5ac17209..aa6faf66 100644 --- a/man/bart_specialized.Rd +++ b/man/bart_cpp_loop_generalized.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/bart.R -\name{bart_specialized} -\alias{bart_specialized} +\name{bart_cpp_loop_generalized} +\alias{bart_cpp_loop_generalized} \title{Run the BART algorithm for supervised learning.} \usage{ -bart_specialized( +bart_cpp_loop_generalized( X_train, y_train, W_train = NULL, diff --git a/man/bart_cpp_loop_specialized.Rd b/man/bart_cpp_loop_specialized.Rd new file mode 100644 index 00000000..03066106 --- /dev/null +++ b/man/bart_cpp_loop_specialized.Rd @@ -0,0 +1,121 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{bart_cpp_loop_specialized} +\alias{bart_cpp_loop_specialized} +\title{Run the BART algorithm for supervised learning.} +\usage{ +bart_cpp_loop_specialized( + X_train, + y_train, + X_test = NULL, + cutpoint_grid_size = 100, + tau_init = NULL, + alpha = 0.95, + beta = 2, + min_samples_leaf = 5, + max_depth = -1, + nu = 3, + lambda = NULL, + a_leaf = 3, + b_leaf = NULL, + q = 0.9, + sigma2_init = NULL, + variable_weights = NULL, + num_trees = 200, + num_gfr = 5, + num_burnin = 
0, + num_mcmc = 100, + random_seed = -1, + keep_burnin = F, + keep_gfr = F, + verbose = F +) +} +\arguments{ +\item{X_train}{Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. +Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be +preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, +categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata +that the column is ordered categorical).} + +\item{y_train}{Outcome to be modeled by the ensemble.} + +\item{X_test}{(Optional) Test set of covariates used to define "out of sample" evaluation data. +May be provided either as a dataframe or a matrix, but the format of \code{X_test} must be consistent with +that of \code{X_train}.} + +\item{cutpoint_grid_size}{Maximum size of the "grid" of potential cutpoints to consider. Default: 100.} + +\item{tau_init}{Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here.} + +\item{alpha}{Prior probability of splitting for a tree of depth 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.} + +\item{beta}{Exponent that decreases split probabilities for nodes of depth > 0. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.} + +\item{min_samples_leaf}{Minimum allowable size of a leaf, in terms of training samples. Default: 5.} + +\item{max_depth}{Maximum depth of any tree in the ensemble. Default: 10. Can be overriden with \code{-1} which does not enforce any depth limits on trees.} + +\item{nu}{Shape parameter in the \code{IG(nu, nu*lambda)} global error variance model. Default: 3.} + +\item{lambda}{Component of the scale parameter in the \code{IG(nu, nu*lambda)} global error variance prior. 
If not specified, this is calibrated as in Sparapani et al (2021).} + +\item{a_leaf}{Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Default: 3.} + +\item{b_leaf}{Scale parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Calibrated internally as \code{0.5/num_trees} if not set here.} + +\item{q}{Quantile used to calibrated \code{lambda} as in Sparapani et al (2021). Default: 0.9.} + +\item{sigma2_init}{Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.} + +\item{variable_weights}{Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here.} + +\item{num_trees}{Number of trees in the ensemble. Default: 200.} + +\item{num_gfr}{Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.} + +\item{num_burnin}{Number of "burn-in" iterations of the MCMC sampler. Default: 0.} + +\item{num_mcmc}{Number of "retained" iterations of the MCMC sampler. Default: 100.} + +\item{random_seed}{Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}.} + +\item{keep_burnin}{Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.} + +\item{keep_gfr}{Whether or not "grow-from-root" samples should be included in cached predictions. Default TRUE. Ignored if num_mcmc = 0.} + +\item{verbose}{Whether or not to print progress during the sampling loops. Default: FALSE.} + +\item{leaf_model}{Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). 
Default: 0.} +} +\value{ +List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). +} +\description{ +Run the BART algorithm for supervised learning. +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart_cpp_loop_specialized(X_train = X_train, y_train = y_train, X_test = X_test) +# plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual") +# abline(0,1,col="red",lty=3,lwd=3) +} diff --git a/src/R_bart.cpp b/src/R_bart.cpp index df65f7aa..0e87c948 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -1088,3 +1088,50 @@ cpp11::external_pointer run_bart_cpp_nobasis_notest_norfx // Release management of the pointer to R session return cpp11::external_pointer(bart_result_ptr_.release()); } + +[[cpp11::register]] +cpp11::external_pointer run_bart_specialized_cpp( + cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, + cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, + int output_dimension, bool is_leaf_constant, double alpha, double beta, + int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, + double nu, double lamb, double leaf_variance_init, double global_variance_init, + int num_gfr, int num_burnin, int num_mcmc, int random_seed, int max_depth +) { + // Create smart pointer to newly allocated object + std::unique_ptr 
bart_result_ptr_ = std::make_unique(num_trees, output_dimension, is_leaf_constant); + + // Convert variable weights to std::vector + std::vector var_weights_vector(variable_weights.size()); + for (int i = 0; i < variable_weights.size(); i++) { + var_weights_vector[i] = variable_weights[i]; + } + + // Convert feature types to std::vector + std::vector feature_types_vector(feature_types.size()); + for (int i = 0; i < feature_types.size(); i++) { + feature_types_vector[i] = static_cast(feature_types[i]); + } + + // Create BART dispatcher and add data + StochTree::BARTDispatcherSimplified bart_dispatcher{}; + double* covariate_data_ptr = REAL(PROTECT(covariates)); + double* outcome_data_ptr = REAL(PROTECT(outcome)); + bart_dispatcher.AddDataset(covariate_data_ptr, num_rows, num_covariates, false, true); + bart_dispatcher.AddTrainOutcome(outcome_data_ptr, num_rows); + + // Run the BART sampling loop + bart_dispatcher.RunSampler( + *bart_result_ptr_.get(), feature_types_vector, var_weights_vector, + num_trees, num_gfr, num_burnin, num_mcmc, global_variance_init, leaf_variance_init, + alpha, beta, nu, lamb, a_leaf, b_leaf, min_samples_leaf, cutpoint_grid_size, + random_seed, max_depth + ); + + // Unprotect pointers to R data + UNPROTECT(2); + + // Release management of the pointer to R session + return cpp11::external_pointer(bart_result_ptr_.release()); +} + diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 7fdcee69..413efea6 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -61,6 +61,13 @@ extern "C" SEXP _stochtree_run_bart_cpp_nobasis_notest_norfx(SEXP covariates_tra return cpp11::as_sexp(run_bart_cpp_nobasis_notest_norfx(cpp11::as_cpp>(covariates_train), cpp11::as_cpp>(outcome_train), cpp11::as_cpp>(num_rows_train), cpp11::as_cpp>(num_covariates_train), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), 
cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_cov_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(sample_global_var), cpp11::as_cpp>(sample_leaf_var))); END_CPP11 } +// R_bart.cpp +cpp11::external_pointer run_bart_specialized_cpp(cpp11::doubles covariates, cpp11::doubles outcome, cpp11::integers feature_types, cpp11::doubles variable_weights, int num_rows, int num_covariates, int num_trees, int output_dimension, bool is_leaf_constant, double alpha, double beta, int min_samples_leaf, int cutpoint_grid_size, double a_leaf, double b_leaf, double nu, double lamb, double leaf_variance_init, double global_variance_init, int num_gfr, int num_burnin, int num_mcmc, int random_seed, int max_depth); +extern "C" SEXP _stochtree_run_bart_specialized_cpp(SEXP covariates, SEXP outcome, SEXP feature_types, SEXP variable_weights, SEXP num_rows, SEXP num_covariates, SEXP num_trees, SEXP output_dimension, SEXP is_leaf_constant, SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP cutpoint_grid_size, SEXP a_leaf, SEXP b_leaf, SEXP nu, SEXP lamb, SEXP leaf_variance_init, SEXP global_variance_init, SEXP num_gfr, SEXP num_burnin, SEXP num_mcmc, SEXP random_seed, SEXP max_depth) { + BEGIN_CPP11 + return cpp11::as_sexp(run_bart_specialized_cpp(cpp11::as_cpp>(covariates), cpp11::as_cpp>(outcome), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_rows), cpp11::as_cpp>(num_covariates), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), 
cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(leaf_variance_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(max_depth))); + END_CPP11 +} // R_data.cpp cpp11::external_pointer create_forest_dataset_cpp(); extern "C" SEXP _stochtree_create_forest_dataset_cpp() { @@ -933,6 +940,7 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_run_bart_cpp_nobasis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_rfx, 37}, {"_stochtree_run_bart_cpp_nobasis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_norfx, 30}, {"_stochtree_run_bart_cpp_nobasis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_rfx, 44}, + {"_stochtree_run_bart_specialized_cpp", (DL_FUNC) &_stochtree_run_bart_specialized_cpp, 24}, {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 13}, {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 13}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 4}, From d90d74ff40d26243ffbf88ab03ebf5b60f5d094a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 19 Jul 2024 17:57:17 -0400 Subject: [PATCH 12/18] Debug script to compare implementations of the BART sampling loop --- tools/debug/cpp_loop_refactor.R | 102 ++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tools/debug/cpp_loop_refactor.R diff --git a/tools/debug/cpp_loop_refactor.R b/tools/debug/cpp_loop_refactor.R new file mode 100644 index 00000000..4605605c --- /dev/null +++ b/tools/debug/cpp_loop_refactor.R @@ -0,0 +1,102 @@ +# Load libraries +library(stochtree) +library(rnn) + +# Random seed +random_seed <- 1234 +set.seed(random_seed) + +# Fixed parameters +sample_size <- 10000 +alpha <- 1.0 +beta <- 0.1 +ntree <- 50 +num_iter <- 10 +num_gfr <- 10 +num_burnin <- 0 
+num_mcmc <- 10 +min_samples_leaf <- 5 +nu <- 3 +lambda <- NULL +q <- 0.9 +sigma2_init <- NULL +sample_tau <- F +sample_sigma <- T + +# Generate data, choice of DGPs: +# (1) the "deep interaction" classification DGP +# (2) partitioned linear model (with split variables and basis included as BART covariates) +dgp_num <- 1 +if (dgp_num == 1) { + # Initial DGP setup + n0 <- 50 + p <- 10 + n <- n0*(2^p) + k <- 2 + p1 <- 20 + noise <- 0.1 + + # Full factorial covariate reference frame + xtemp <- as.data.frame(as.factor(rep(0:(2^p-1),n0))) + xtemp1 <- rep(0:(2^p-1),n0) + x <- t(sapply(xtemp1,function(j) as.numeric(int2bin(j,p)))) + X_superset <- x*abs(rnorm(length(x))) - (1-x)*abs(rnorm(length(x))) + + # Generate outcome + M <- model.matrix(~.-1,data = xtemp) + M <- cbind(rep(1,n),M) + beta.true <- -10*abs(rnorm(ncol(M))) + beta.true[1] <- 0.5 + non_zero_betas <- c(1,sample(1:ncol(M), p1-1)) + beta.true[-non_zero_betas] <- 0 + Y <- M %*% beta.true + rnorm(n, 0, noise) + y_superset <- as.numeric(Y>0) + + # Downsample to desired n + subset_inds <- order(sample(1:nrow(X_superset), sample_size, replace = F)) + X <- X_superset[subset_inds,] + y <- y_superset[subset_inds] +} else if (dgp_num == 2) { + p <- 10 + snr <- 2 + X <- matrix(runif(sample_size*p), ncol = p) + f_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2]) + ) + noise_sd <- sd(f_X) / snr + y <- f_X + rnorm(sample_size, 0, noise_sd) +} else stop("dgp_num must be 1 or 2") + +# Switch between +# (1) the R-dispatched loop, +# (2) the "generalized" C++ sampling loop, and +# (3) the "streamlined" / "specialized" C++ sampling loop that only samples trees +# and sigma^2 (error variance parameter) +sampler_choice <- 1 +if (sampler_choice == 1) { + bart_obj <- stochtree::bart( + X_train = X, y_train = y, alpha = alpha, beta = beta, + min_samples_leaf = 
min_samples_leaf, nu = nu, lambda = lambda, q = q, + sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, + num_burnin = num_burnin, num_mcmc = num_mcmc, sample_tau = sample_tau, + sample_sigma = sample_sigma, random_seed = random_seed + ) +} else if (sampler_choice == 2) { + bart_obj <- stochtree::bart_cpp_loop_generalized( + X_train = X, y_train = y, alpha = alpha, beta = beta, + min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, + sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, + num_burnin = num_burnin, num_mcmc = num_mcmc, sample_leaf_var = sample_tau, + sample_global_var = sample_sigma, random_seed = random_seed + ) +} else if (sampler_choice == 3) { + bart_obj <- stochtree::bart_cpp_loop_specialized( + X_train = X, y_train = y, alpha = alpha, beta = beta, + min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, + sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, + num_burnin = num_burnin, num_mcmc = num_mcmc, random_seed = random_seed + ) +} else stop("sampler_choice must be 1, 2, or 3") From f673cb54aaf0f336a641036c7b3b7dc35a0a7b81 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 20 Jul 2024 03:36:15 -0400 Subject: [PATCH 13/18] Added functions to inspect tree depth --- NAMESPACE | 2 + R/bart.R | 22 +++++++++ R/cpp11.R | 20 ++++++++ R/forest.R | 24 ++++++++++ include/stochtree/container.h | 13 +++++ include/stochtree/ensemble.h | 14 ++++++ include/stochtree/tree.h | 42 ++++++++++++++++ man/ForestSamples.Rd | 58 +++++++++++++++++++++++ man/average_max_depth_bart_generalized.Rd | 17 +++++++ man/average_max_depth_bart_specialized.Rd | 17 +++++++ src/R_bart.cpp | 9 ++++ src/cpp11.cpp | 40 ++++++++++++++++ src/forest.cpp | 15 ++++++ src/tree.cpp | 10 ++++ test/cpp/test_tree.cpp | 7 +++ tools/debug/cpp_loop_refactor.R | 57 ++++++++++++---------- 16 files changed, 342 insertions(+), 25 deletions(-) create mode 100644 man/average_max_depth_bart_generalized.Rd create mode 100644 
man/average_max_depth_bart_specialized.Rd diff --git a/NAMESPACE b/NAMESPACE index 03201872..7fa0b4ef 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,6 +4,8 @@ S3method(getRandomEffectSamples,bartmodel) S3method(getRandomEffectSamples,bcf) S3method(predict,bartmodel) S3method(predict,bcf) +export(average_max_depth_bart_generalized) +export(average_max_depth_bart_specialized) export(bart) export(bart_cpp_loop_generalized) export(bart_cpp_loop_specialized) diff --git a/R/bart.R b/R/bart.R index d129e1cf..ef21c906 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1041,6 +1041,7 @@ bart_cpp_loop_generalized <- function( ) result <- list( # "forests" = forest_samples, + "bart_result" = bart_result_ptr, "model_params" = model_params # "y_hat_train" = y_hat_train, # "train_set_metadata" = X_train_metadata, @@ -1238,6 +1239,7 @@ bart_cpp_loop_specialized <- function( "sample_tau" = F ) result <- list( + "bart_result" = bart_result_ptr, "model_params" = model_params ) class(result) <- "bartcppsimplified" @@ -1310,3 +1312,23 @@ getRandomEffectSamples.bartmodel <- function(object, ...){ return(result) } + +#' Return the average max depth of all trees and all ensembles in a container of samples +#' +#' @param bart_result External pointer to a bart result object +#' +#' @return Average maximum depth +#' @export +average_max_depth_bart_generalized <- function(bart_result) { + average_max_depth_bart_generalized_cpp(bart_result) +} + +#' Return the average max depth of all trees and all ensembles in a container of samples +#' +#' @param bart_result External pointer to a bart result object +#' +#' @return Average maximum depth +#' @export +average_max_depth_bart_specialized <- function(bart_result) { + average_max_depth_bart_specialized_cpp(bart_result) +} diff --git a/R/cpp11.R b/R/cpp11.R index 9e546e78..44ec3b32 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -36,6 +36,14 @@ run_bart_specialized_cpp <- function(covariates, outcome, feature_types, variabl 
.Call(`_stochtree_run_bart_specialized_cpp`, covariates, outcome, feature_types, variable_weights, num_rows, num_covariates, num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lamb, leaf_variance_init, global_variance_init, num_gfr, num_burnin, num_mcmc, random_seed, max_depth) } +average_max_depth_bart_generalized_cpp <- function(bart_result) { + .Call(`_stochtree_average_max_depth_bart_generalized_cpp`, bart_result) +} + +average_max_depth_bart_specialized_cpp <- function(bart_result) { + .Call(`_stochtree_average_max_depth_bart_specialized_cpp`, bart_result) +} + create_forest_dataset_cpp <- function() { .Call(`_stochtree_create_forest_dataset_cpp`) } @@ -224,6 +232,18 @@ num_samples_forest_container_cpp <- function(forest_samples) { .Call(`_stochtree_num_samples_forest_container_cpp`, forest_samples) } +ensemble_tree_max_depth_forest_container_cpp <- function(forest_samples, ensemble_num, tree_num) { + .Call(`_stochtree_ensemble_tree_max_depth_forest_container_cpp`, forest_samples, ensemble_num, tree_num) +} + +ensemble_average_max_depth_forest_container_cpp <- function(forest_samples, ensemble_num) { + .Call(`_stochtree_ensemble_average_max_depth_forest_container_cpp`, forest_samples, ensemble_num) +} + +average_max_depth_forest_container_cpp <- function(forest_samples) { + .Call(`_stochtree_average_max_depth_forest_container_cpp`, forest_samples) +} + num_trees_forest_container_cpp <- function(forest_samples) { .Call(`_stochtree_num_trees_forest_container_cpp`, forest_samples) } diff --git a/R/forest.R b/R/forest.R index 8e5c0559..10616b45 100644 --- a/R/forest.R +++ b/R/forest.R @@ -162,6 +162,30 @@ ForestSamples <- R6::R6Class( #' @return Leaf node parameter size output_dimension = function() { return(output_dimension_forest_container_cpp(self$forest_container_ptr)) + }, + + #' @description + #' Maximum depth of a specific tree in a specific ensemble in a `ForestContainer` object + #' @param 
ensemble_num Ensemble number + #' @param tree_num Tree index within ensemble `ensemble_num` + #' @return Maximum leaf depth + ensemble_tree_max_depth = function(ensemble_num, tree_num) { + return(ensemble_tree_max_depth_forest_container_cpp(self$forest_container_ptr, ensemble_num, tree_num)) + }, + + #' @description + #' Average the maximum depth of each tree in a given ensemble in a `ForestContainer` object + #' @param ensemble_num Ensemble number + #' @return Average maximum depth + average_ensemble_max_depth = function(ensemble_num) { + return(ensemble_average_max_depth_forest_container_cpp(self$forest_container_ptr, ensemble_num)) + }, + + #' @description + #' Average the maximum depth of each tree in each ensemble in a `ForestContainer` object + #' @return Average maximum depth + average_max_depth = function() { + return(average_max_depth_forest_container_cpp(self$forest_container_ptr)) } ) ) diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 880f7564..b3a7d806 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -42,6 +42,19 @@ class ForestContainer { inline int32_t NumTrees() {return num_trees_;} inline int32_t NumTrees(int ensemble_num) {return forests_[ensemble_num]->NumTrees();} inline int32_t NumLeaves(int ensemble_num) {return forests_[ensemble_num]->NumLeaves();} + inline int32_t EnsembleTreeMaxDepth(int ensemble_num, int tree_num) {return forests_[ensemble_num]->TreeMaxDepth(tree_num);} + inline double EnsembleAverageMaxDepth(int ensemble_num) {return forests_[ensemble_num]->AverageMaxDepth();} + inline double AverageMaxDepth() { + double numerator = 0.; + double denominator = 0.; + for (int i = 0; i < num_samples_; i++) { + for (int j = 0; j < num_trees_; j++) { + numerator += static_cast(forests_[i]->TreeMaxDepth(j)); + denominator += 1.; + } + } + return numerator / denominator; + } inline int32_t OutputDimension() {return output_dimension_;} inline int32_t OutputDimension(int ensemble_num) 
{return forests_[ensemble_num]->OutputDimension();} inline bool IsLeafConstant() {return is_leaf_constant_;} diff --git a/include/stochtree/ensemble.h b/include/stochtree/ensemble.h index d6d2b660..72a47b7c 100644 --- a/include/stochtree/ensemble.h +++ b/include/stochtree/ensemble.h @@ -194,6 +194,20 @@ class TreeEnsemble { return is_leaf_constant_; } + inline int32_t TreeMaxDepth(int tree_num) { + return trees_[tree_num]->MaxLeafDepth(); + } + + inline double AverageMaxDepth() { + double numerator = 0.; + double denominator = 0.; + for (int i = 0; i < num_trees_; i++) { + numerator += static_cast(TreeMaxDepth(i)); + denominator += 1.; + } + return numerator / denominator; + } + inline bool AllRoots() { for (int i = 0; i < num_trees_; i++) { if (!trees_[i]->IsRoot()) { diff --git a/include/stochtree/tree.h b/include/stochtree/tree.h index bbacd561..6d9a11a3 100644 --- a/include/stochtree/tree.h +++ b/include/stochtree/tree.h @@ -300,6 +300,14 @@ class Tree { return parent_[nid] == kInvalidNodeId; } + /*! + * \brief Whether the node has been deleted + * \param nid ID of node being queried + */ + bool IsDeleted(std::int32_t nid) const { + return node_deleted_[nid]; + } + /*! * \brief Get leaf value of the leaf node * \param nid ID of node being queried @@ -326,6 +334,39 @@ class Tree { return leaf_vector_[offset_begin + dim_id]; } } + + /*! 
+ * \brief Get maximum depth of all of the leaf nodes + */ + std::int32_t MaxLeafDepth() const { + std::int32_t max_depth = 0; + std::stack nodes; + std::stack node_depths; + nodes.push(kRoot); + node_depths.push(0); + auto &self = *this; + while (!nodes.empty()) { + auto nidx = nodes.top(); + nodes.pop(); + auto node_depth = node_depths.top(); + node_depths.pop(); + bool valid_node = !self.IsDeleted(nidx); + if (valid_node) { + if (node_depth > max_depth) max_depth = node_depth; + auto left = self.LeftChild(nidx); + auto right = self.RightChild(nidx); + if (left != Tree::kInvalidNodeId) { + nodes.push(left); + node_depths.push(node_depth+1); + } + if (right != Tree::kInvalidNodeId) { + nodes.push(right); + node_depths.push(node_depth+1); + } + } + } + return max_depth; + } /*! * \brief get leaf vector of the leaf node; useful for multi-output trees @@ -652,6 +693,7 @@ class Tree { std::vector split_index_; std::vector leaf_value_; std::vector threshold_; + std::vector node_deleted_; std::vector internal_nodes_; std::vector leaves_; std::vector leaf_parents_; diff --git a/man/ForestSamples.Rd b/man/ForestSamples.Rd index f3749373..b64b4faf 100644 --- a/man/ForestSamples.Rd +++ b/man/ForestSamples.Rd @@ -28,6 +28,9 @@ Wrapper around a C++ container of tree ensembles \item \href{#method-ForestSamples-num_samples}{\code{ForestSamples$num_samples()}} \item \href{#method-ForestSamples-num_trees}{\code{ForestSamples$num_trees()}} \item \href{#method-ForestSamples-output_dimension}{\code{ForestSamples$output_dimension()}} +\item \href{#method-ForestSamples-ensemble_tree_max_depth}{\code{ForestSamples$ensemble_tree_max_depth()}} +\item \href{#method-ForestSamples-average_ensemble_max_depth}{\code{ForestSamples$average_ensemble_max_depth()}} +\item \href{#method-ForestSamples-average_max_depth}{\code{ForestSamples$average_max_depth()}} } } \if{html}{\out{
}} @@ -275,4 +278,59 @@ Return output dimension of trees in a \code{ForestContainer} object Leaf node parameter size } } +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-ensemble_tree_max_depth}{}}} +\subsection{Method \code{ensemble_tree_max_depth()}}{ +Maximum depth of a specific tree in a specific ensemble in a \code{ForestContainer} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$ensemble_tree_max_depth(ensemble_num, tree_num)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{ensemble_num}}{Ensemble number} + +\item{\code{tree_num}}{Tree index within ensemble \code{ensemble_num}} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Maximum leaf depth +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-average_ensemble_max_depth}{}}} +\subsection{Method \code{average_ensemble_max_depth()}}{ +Average the maximum depth of each tree in a given ensemble in a \code{ForestContainer} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$average_ensemble_max_depth(ensemble_num)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{ensemble_num}}{Ensemble number} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Average maximum depth +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-average_max_depth}{}}} +\subsection{Method \code{average_max_depth()}}{ +Average the maximum depth of each tree in each ensemble in a \code{ForestContainer} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$average_max_depth()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Average maximum depth +} +} } diff --git a/man/average_max_depth_bart_generalized.Rd b/man/average_max_depth_bart_generalized.Rd new file mode 100644 index 00000000..972a6472 --- /dev/null +++ b/man/average_max_depth_bart_generalized.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{average_max_depth_bart_generalized} +\alias{average_max_depth_bart_generalized} +\title{Return the average max depth of all trees and all ensembles in a container of samples} +\usage{ +average_max_depth_bart_generalized(bart_result) +} +\arguments{ +\item{bart_result}{External pointer to a bart result object} +} +\value{ +Average maximum depth +} +\description{ +Return the average max depth of all trees and all ensembles in a container of samples +} diff --git a/man/average_max_depth_bart_specialized.Rd b/man/average_max_depth_bart_specialized.Rd new file mode 100644 index 00000000..c28515fb --- /dev/null +++ b/man/average_max_depth_bart_specialized.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/bart.R +\name{average_max_depth_bart_specialized} +\alias{average_max_depth_bart_specialized} +\title{Return the average max depth of all trees and all ensembles in a container of samples} +\usage{ +average_max_depth_bart_specialized(bart_result) +} +\arguments{ +\item{bart_result}{External pointer to a bart result object} +} +\value{ +Average maximum depth +} +\description{ +Return the average max depth of all trees and all ensembles in a container of samples +} diff --git a/src/R_bart.cpp b/src/R_bart.cpp index 0e87c948..4320eb19 100644 --- a/src/R_bart.cpp +++ b/src/R_bart.cpp @@ -1135,3 +1135,12 @@ cpp11::external_pointer run_bart_specialized_cp return cpp11::external_pointer(bart_result_ptr_.release()); } +[[cpp11::register]] +double average_max_depth_bart_generalized_cpp(cpp11::external_pointer bart_result) { + return 
bart_result->GetForests()->AverageMaxDepth(); +} + +[[cpp11::register]] +double average_max_depth_bart_specialized_cpp(cpp11::external_pointer bart_result) { + return (bart_result->GetForests()).AverageMaxDepth(); +} diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 413efea6..182fd94d 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -68,6 +68,20 @@ extern "C" SEXP _stochtree_run_bart_specialized_cpp(SEXP covariates, SEXP outcom return cpp11::as_sexp(run_bart_specialized_cpp(cpp11::as_cpp>(covariates), cpp11::as_cpp>(outcome), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(num_rows), cpp11::as_cpp>(num_covariates), cpp11::as_cpp>(num_trees), cpp11::as_cpp>(output_dimension), cpp11::as_cpp>(is_leaf_constant), cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>(a_leaf), cpp11::as_cpp>(b_leaf), cpp11::as_cpp>(nu), cpp11::as_cpp>(lamb), cpp11::as_cpp>(leaf_variance_init), cpp11::as_cpp>(global_variance_init), cpp11::as_cpp>(num_gfr), cpp11::as_cpp>(num_burnin), cpp11::as_cpp>(num_mcmc), cpp11::as_cpp>(random_seed), cpp11::as_cpp>(max_depth))); END_CPP11 } +// R_bart.cpp +double average_max_depth_bart_generalized_cpp(cpp11::external_pointer bart_result); +extern "C" SEXP _stochtree_average_max_depth_bart_generalized_cpp(SEXP bart_result) { + BEGIN_CPP11 + return cpp11::as_sexp(average_max_depth_bart_generalized_cpp(cpp11::as_cpp>>(bart_result))); + END_CPP11 +} +// R_bart.cpp +double average_max_depth_bart_specialized_cpp(cpp11::external_pointer bart_result); +extern "C" SEXP _stochtree_average_max_depth_bart_specialized_cpp(SEXP bart_result) { + BEGIN_CPP11 + return cpp11::as_sexp(average_max_depth_bart_specialized_cpp(cpp11::as_cpp>>(bart_result))); + END_CPP11 +} // R_data.cpp cpp11::external_pointer create_forest_dataset_cpp(); extern "C" SEXP _stochtree_create_forest_dataset_cpp() { @@ -412,6 +426,27 @@ extern "C" SEXP 
_stochtree_num_samples_forest_container_cpp(SEXP forest_samples) END_CPP11 } // forest.cpp +int ensemble_tree_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples, int ensemble_num, int tree_num); +extern "C" SEXP _stochtree_ensemble_tree_max_depth_forest_container_cpp(SEXP forest_samples, SEXP ensemble_num, SEXP tree_num) { + BEGIN_CPP11 + return cpp11::as_sexp(ensemble_tree_max_depth_forest_container_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>(ensemble_num), cpp11::as_cpp>(tree_num))); + END_CPP11 +} +// forest.cpp +double ensemble_average_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples, int ensemble_num); +extern "C" SEXP _stochtree_ensemble_average_max_depth_forest_container_cpp(SEXP forest_samples, SEXP ensemble_num) { + BEGIN_CPP11 + return cpp11::as_sexp(ensemble_average_max_depth_forest_container_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>(ensemble_num))); + END_CPP11 +} +// forest.cpp +double average_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples); +extern "C" SEXP _stochtree_average_max_depth_forest_container_cpp(SEXP forest_samples) { + BEGIN_CPP11 + return cpp11::as_sexp(average_max_depth_forest_container_cpp(cpp11::as_cpp>>(forest_samples))); + END_CPP11 +} +// forest.cpp int num_trees_forest_container_cpp(cpp11::external_pointer forest_samples); extern "C" SEXP _stochtree_num_trees_forest_container_cpp(SEXP forest_samples) { BEGIN_CPP11 @@ -838,6 +873,9 @@ extern "C" { static const R_CallMethodDef CallEntries[] = { {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1}, {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, + {"_stochtree_average_max_depth_bart_generalized_cpp", (DL_FUNC) &_stochtree_average_max_depth_bart_generalized_cpp, 1}, + {"_stochtree_average_max_depth_bart_specialized_cpp", (DL_FUNC) &_stochtree_average_max_depth_bart_specialized_cpp, 1}, + 
{"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1}, {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, {"_stochtree_create_rfx_dataset_cpp", (DL_FUNC) &_stochtree_create_rfx_dataset_cpp, 0}, @@ -846,6 +884,8 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_dataset_num_basis_cpp, 1}, {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, + {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, + {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 3}, {"_stochtree_forest_container_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_cpp, 2}, {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, diff --git a/src/forest.cpp b/src/forest.cpp index 7f00072a..b6758bad 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -41,6 +41,21 @@ int num_samples_forest_container_cpp(cpp11::external_pointerNumSamples(); } +[[cpp11::register]] +int ensemble_tree_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples, int ensemble_num, int tree_num) { + return forest_samples->EnsembleTreeMaxDepth(ensemble_num, tree_num); +} + +[[cpp11::register]] +double ensemble_average_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples, int ensemble_num) { + return forest_samples->EnsembleAverageMaxDepth(ensemble_num); +} + +[[cpp11::register]] +double 
average_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples) { + return forest_samples->AverageMaxDepth(); +} + [[cpp11::register]] int num_trees_forest_container_cpp(cpp11::external_pointer forest_samples) { return forest_samples->NumTrees(); diff --git a/src/tree.cpp b/src/tree.cpp index 68ac6327..3dbf227c 100644 --- a/src/tree.cpp +++ b/src/tree.cpp @@ -96,6 +96,7 @@ void Tree::CloneFromTree(Tree* tree) { split_index_ = tree->split_index_; leaf_value_ = tree->leaf_value_; threshold_ = tree->threshold_; + node_deleted_ = tree->node_deleted_; internal_nodes_ = tree->internal_nodes_; leaves_ = tree->leaves_; leaf_parents_ = tree->leaf_parents_; @@ -116,6 +117,7 @@ std::int32_t Tree::AllocNode() { // Reuse a "deleted" node if available if (num_deleted_nodes != 0) { std::int32_t nid = deleted_nodes_.back(); + node_deleted_[nid] = false; deleted_nodes_.pop_back(); --num_deleted_nodes; return nid; @@ -130,6 +132,7 @@ std::int32_t Tree::AllocNode() { split_index_.push_back(-1); leaf_value_.push_back(static_cast(0)); threshold_.push_back(static_cast(0)); + node_deleted_.push_back(false); // THIS is a placeholder, currently set after AllocNode is called ... // ... to be refactored ... parent_.push_back(static_cast(0)); @@ -154,6 +157,7 @@ void Tree::DeleteNode(std::int32_t nid) { deleted_nodes_.push_back(nid); ++num_deleted_nodes; + node_deleted_[nid] = true; // Remove from vectors that track leaves, leaf parents, internal nodes, etc... 
leaves_.erase(std::remove(leaves_.begin(), leaves_.end(), nid), leaves_.end()); @@ -296,6 +300,7 @@ void Tree::Reset() { leaf_value_.clear(); threshold_.clear(); parent_.clear(); + node_deleted_.clear(); num_nodes = 0; has_categorical_split_ = false; @@ -329,6 +334,7 @@ void Tree::Init(std::int32_t output_dimension) { leaf_value_.clear(); threshold_.clear(); parent_.clear(); + node_deleted_.clear(); num_nodes = 0; has_categorical_split_ = false; @@ -462,6 +468,7 @@ void TreeNodeVectorsToJson(json& obj, Tree* tree) { tree_array_map.emplace(std::pair("split_index", json::array())); tree_array_map.emplace(std::pair("leaf_value", json::array())); tree_array_map.emplace(std::pair("threshold", json::array())); + tree_array_map.emplace(std::pair("node_deleted", json::array())); tree_array_map.emplace(std::pair("leaf_vector_begin", json::array())); tree_array_map.emplace(std::pair("leaf_vector_end", json::array())); tree_array_map.emplace(std::pair("category_list_begin", json::array())); @@ -480,6 +487,7 @@ void TreeNodeVectorsToJson(json& obj, Tree* tree) { tree_array_map["split_index"].emplace_back(tree->split_index_[i]); tree_array_map["leaf_value"].emplace_back(tree->leaf_value_[i]); tree_array_map["threshold"].emplace_back(tree->threshold_[i]); + tree_array_map["node_deleted"].emplace_back(tree->node_deleted_[i]); tree_array_map["leaf_vector_begin"].emplace_back(static_cast(tree->leaf_vector_begin_[i])); tree_array_map["leaf_vector_end"].emplace_back(static_cast(tree->leaf_vector_end_[i])); tree_array_map["category_list_begin"].emplace_back(static_cast(tree->category_list_begin_[i])); @@ -574,6 +582,7 @@ void JsonToTreeNodeVectors(const json& tree_json, Tree* tree) { tree->split_index_.clear(); tree->leaf_value_.clear(); tree->threshold_.clear(); + tree->node_deleted_.clear(); tree->node_type_.clear(); tree->leaf_vector_begin_.clear(); tree->leaf_vector_end_.clear(); @@ -588,6 +597,7 @@ void JsonToTreeNodeVectors(const json& tree_json, Tree* tree) { 
tree->split_index_.push_back(tree_json.at("split_index").at(i)); tree->leaf_value_.push_back(tree_json.at("leaf_value").at(i)); tree->threshold_.push_back(tree_json.at("threshold").at(i)); + tree->node_deleted_.push_back(tree_json.at("node_deleted").at(i)); // Handle type conversions for node_type, leaf_vector_begin/end, and category_list_begin/end tree->node_type_.push_back(static_cast(tree_json.at("node_type").at(i))); tree->leaf_vector_begin_.push_back(static_cast(tree_json.at("leaf_vector_begin").at(i))); diff --git a/test/cpp/test_tree.cpp b/test/cpp/test_tree.cpp index ba8c39bb..e9ccff7d 100644 --- a/test/cpp/test_tree.cpp +++ b/test/cpp/test_tree.cpp @@ -33,12 +33,17 @@ TEST(Tree, UnivariateTreeCopyConstruction) { StochTree::Tree tree_2; StochTree::TreeSplit split; tree_1.Init(1); + + // Check max depth + ASSERT_EQ(tree_1.MaxLeafDepth(), 0); // Perform two splits split = StochTree::TreeSplit(0.5); tree_1.ExpandNode(0, 0, split, 0., 0.); + ASSERT_EQ(tree_1.MaxLeafDepth(), 1); split = StochTree::TreeSplit(0.75); tree_1.ExpandNode(1, 1, split, 0., 0.); + ASSERT_EQ(tree_1.MaxLeafDepth(), 2); ASSERT_EQ(tree_1.NumValidNodes(), 5); ASSERT_EQ(tree_1.NumLeafParents(), 1); @@ -56,6 +61,7 @@ TEST(Tree, UnivariateTreeCopyConstruction) { // Perform another split split = StochTree::TreeSplit(0.6); tree_1.ExpandNode(3, 2, split, 0., 0.); + ASSERT_EQ(tree_1.MaxLeafDepth(), 3); ASSERT_EQ(tree_1.NumValidNodes(), 7); ASSERT_EQ(tree_1.NumLeaves(), 4); ASSERT_EQ(tree_1.NumLeafParents(), 1); @@ -73,6 +79,7 @@ TEST(Tree, UnivariateTreeCopyConstruction) { // Prune node 3 to a leaf tree_1.CollapseToLeaf(3, 0.); + ASSERT_EQ(tree_1.MaxLeafDepth(), 2); ASSERT_EQ(tree_1.NumValidNodes(), 5); ASSERT_EQ(tree_1.NumLeaves(), 3); ASSERT_EQ(tree_1.NumLeafParents(), 1); diff --git a/tools/debug/cpp_loop_refactor.R b/tools/debug/cpp_loop_refactor.R index 4605605c..a8cf15a5 100644 --- a/tools/debug/cpp_loop_refactor.R +++ b/tools/debug/cpp_loop_refactor.R @@ -26,7 +26,7 @@ sample_sigma <- T # 
Generate data, choice of DGPs: # (1) the "deep interaction" classification DGP # (2) partitioned linear model (with split variables and basis included as BART covariates) -dgp_num <- 1 +dgp_num <- 2 if (dgp_num == 1) { # Initial DGP setup n0 <- 50 @@ -76,27 +76,34 @@ if (dgp_num == 1) { # (3) the "streamlined" / "specialized" C++ sampling loop that only samples trees # and sigma^2 (error variance parameter) sampler_choice <- 1 -if (sampler_choice == 1) { - bart_obj <- stochtree::bart( - X_train = X, y_train = y, alpha = alpha, beta = beta, - min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, - sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, - num_burnin = num_burnin, num_mcmc = num_mcmc, sample_tau = sample_tau, - sample_sigma = sample_sigma, random_seed = random_seed - ) -} else if (sampler_choice == 2) { - bart_obj <- stochtree::bart_cpp_loop_generalized( - X_train = X, y_train = y, alpha = alpha, beta = beta, - min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, - sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, - num_burnin = num_burnin, num_mcmc = num_mcmc, sample_leaf_var = sample_tau, - sample_global_var = sample_sigma, random_seed = random_seed - ) -} else if (sampler_choice == 3) { - bart_obj <- stochtree::bart_cpp_loop_specialized( - X_train = X, y_train = y, alpha = alpha, beta = beta, - min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, - sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, - num_burnin = num_burnin, num_mcmc = num_mcmc, random_seed = random_seed - ) -} else stop("sampler_choice must be 1, 2, or 3") +system.time({ + if (sampler_choice == 1) { + bart_obj <- stochtree::bart( + X_train = X, y_train = y, alpha = alpha, beta = beta, + min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, + sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, + num_burnin = num_burnin, num_mcmc = num_mcmc, sample_tau = sample_tau, + 
sample_sigma = sample_sigma, random_seed = random_seed + ) + avg_md <- bart_obj$forests$average_max_depth() + } else if (sampler_choice == 2) { + bart_obj <- stochtree::bart_cpp_loop_generalized( + X_train = X, y_train = y, alpha = alpha, beta = beta, + min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, + sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, + num_burnin = num_burnin, num_mcmc = num_mcmc, sample_leaf_var = sample_tau, + sample_global_var = sample_sigma, random_seed = random_seed + ) + avg_md <- average_max_depth_bart_generalized(bart_obj$bart_result) + } else if (sampler_choice == 3) { + bart_obj <- stochtree::bart_cpp_loop_specialized( + X_train = X, y_train = y, alpha = alpha, beta = beta, + min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q, + sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr, + num_burnin = num_burnin, num_mcmc = num_mcmc, random_seed = random_seed + ) + avg_md <- average_max_depth_bart_specialized(bart_obj$bart_result) + } else stop("sampler_choice must be 1, 2, or 3") +}) + +avg_md \ No newline at end of file From 78606c1b75d52759132c225b961aefd308be2a87 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 20 Jul 2024 03:42:44 -0400 Subject: [PATCH 14/18] Fixed bug in "specialized" BART loop --- R/bart.R | 2 +- tools/debug/cpp_loop_refactor.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/bart.R b/R/bart.R index ef21c906..70c1b4b1 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1205,7 +1205,7 @@ bart_cpp_loop_specialized <- function( # Run the BART sampler bart_result_ptr <- run_bart_specialized_cpp( - as.numeric(X_train), y_train, feature_types, variable_weights, nrow(X_train), + as.numeric(X_train), resid_train, feature_types, variable_weights, nrow(X_train), ncol(X_train), num_trees, output_dimension, is_leaf_constant, alpha, beta, min_samples_leaf, cutpoint_grid_size, a_leaf, b_leaf, nu, lambda, tau_init, sigma2_init, num_gfr, num_burnin, 
num_mcmc, random_seed, max_depth diff --git a/tools/debug/cpp_loop_refactor.R b/tools/debug/cpp_loop_refactor.R index a8cf15a5..99e3edd3 100644 --- a/tools/debug/cpp_loop_refactor.R +++ b/tools/debug/cpp_loop_refactor.R @@ -75,7 +75,7 @@ if (dgp_num == 1) { # (2) the "generalized" C++ sampling loop, and # (3) the "streamlined" / "specialized" C++ sampling loop that only samples trees # and sigma^2 (error variance parameter) -sampler_choice <- 1 +sampler_choice <- 3 system.time({ if (sampler_choice == 1) { bart_obj <- stochtree::bart( From f201db687a7f07e9d3787bc7a06391e184403e78 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 20 Jul 2024 03:58:16 -0400 Subject: [PATCH 15/18] Fixed max_depth bug --- R/bart.R | 2 +- include/stochtree/bart.h | 2 +- tools/debug/cpp_loop_refactor.R | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/bart.R b/R/bart.R index 70c1b4b1..8a54a0df 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1118,7 +1118,7 @@ bart_cpp_loop_generalized <- function( bart_cpp_loop_specialized <- function( X_train, y_train, X_test = NULL, cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95, beta = 2.0, min_samples_leaf = 5, - max_depth = -1, nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, + max_depth = 10, nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL, q = 0.9, sigma2_init = NULL, variable_weights = NULL, num_trees = 200, num_gfr = 5, num_burnin = 0, num_mcmc = 100, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F diff --git a/include/stochtree/bart.h b/include/stochtree/bart.h index 1441cafb..5f0fc35e 100644 --- a/include/stochtree/bart.h +++ b/include/stochtree/bart.h @@ -406,7 +406,7 @@ class BARTDispatcherSimplified { // Initialize tracker and tree prior ForestTracker tracker = ForestTracker(train_dataset_.GetCovariates(), feature_types, num_trees, num_train_); - TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf); + TreePrior tree_prior = TreePrior(alpha, beta, min_samples_leaf, max_depth); // Initialize 
variance model GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); diff --git a/tools/debug/cpp_loop_refactor.R b/tools/debug/cpp_loop_refactor.R index 99e3edd3..a8cf15a5 100644 --- a/tools/debug/cpp_loop_refactor.R +++ b/tools/debug/cpp_loop_refactor.R @@ -75,7 +75,7 @@ if (dgp_num == 1) { # (2) the "generalized" C++ sampling loop, and # (3) the "streamlined" / "specialized" C++ sampling loop that only samples trees # and sigma^2 (error variance parameter) -sampler_choice <- 3 +sampler_choice <- 1 system.time({ if (sampler_choice == 1) { bart_obj <- stochtree::bart( From 449b3f0c085d581e96a85a431f7ae914279264e9 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 24 Jul 2024 01:44:02 -0400 Subject: [PATCH 16/18] Updated samplers and tests --- R/model.R | 4 +-- man/ForestModel.Rd | 4 +-- man/ForestSamples.Rd | 62 ++++++++++++++++++++++++++++++-- man/bart_cpp_loop_specialized.Rd | 2 +- src/cpp11.cpp | 16 ++++++++- src/py_stochtree.cpp | 2 +- stochtree/sampler.py | 2 +- test/R/testthat/test-residual.R | 3 +- 8 files changed, 84 insertions(+), 11 deletions(-) diff --git a/R/model.R b/R/model.R index ecebd206..284adcca 100644 --- a/R/model.R +++ b/R/model.R @@ -50,9 +50,9 @@ ForestModel <- R6::R6Class( #' @param alpha Root node split probability in tree prior #' @param beta Depth prior penalty in tree prior #' @param min_samples_leaf Minimum number of samples in a tree leaf - #' @param max_depth Maximum depth of any tree in an ensemble + #' @param max_depth Maximum depth of any tree in an ensemble. Default: `-1`. #' @return A new `ForestModel` object. 
- initialize = function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) { + initialize = function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth = -1) { stopifnot(!is.null(forest_dataset$data_ptr)) self$tracker_ptr <- forest_tracker_cpp(forest_dataset$data_ptr, feature_types, num_trees, n) self$tree_prior_ptr <- tree_prior_cpp(alpha, beta, min_samples_leaf, max_depth) diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd index 3a8cbc6b..82bd4337 100644 --- a/man/ForestModel.Rd +++ b/man/ForestModel.Rd @@ -38,7 +38,7 @@ Create a new ForestModel object. alpha, beta, min_samples_leaf, - max_depth + max_depth = -1 )}\if{html}{\out{}} } @@ -59,7 +59,7 @@ Create a new ForestModel object. \item{\code{min_samples_leaf}}{Minimum number of samples in a tree leaf} -\item{\code{max_depth}}{Maximum depth of any tree in an ensemble} +\item{\code{max_depth}}{Maximum depth of any tree in an ensemble. Default: \code{-1}.} } \if{html}{\out{}} } diff --git a/man/ForestSamples.Rd b/man/ForestSamples.Rd index 683473fb..b629ca1a 100644 --- a/man/ForestSamples.Rd +++ b/man/ForestSamples.Rd @@ -36,6 +36,9 @@ Wrapper around a C++ container of tree ensembles \item \href{#method-ForestSamples-get_forest_split_counts}{\code{ForestSamples$get_forest_split_counts()}} \item \href{#method-ForestSamples-get_aggregate_split_counts}{\code{ForestSamples$get_aggregate_split_counts()}} \item \href{#method-ForestSamples-get_granular_split_counts}{\code{ForestSamples$get_granular_split_counts()}} +\item \href{#method-ForestSamples-ensemble_tree_max_depth}{\code{ForestSamples$ensemble_tree_max_depth()}} +\item \href{#method-ForestSamples-average_ensemble_max_depth}{\code{ForestSamples$average_ensemble_max_depth()}} +\item \href{#method-ForestSamples-average_max_depth}{\code{ForestSamples$average_max_depth()}} } } \if{html}{\out{
}} @@ -434,7 +437,7 @@ Retrieve a vector of split counts for every training set variable in a given for \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestSamples-get_aggregate_split_counts}{}}} \subsection{Method \code{get_aggregate_split_counts()}}{ -Retrieve a vector of split counts for every training set variable in a given forest +Retrieve a vector of split counts for every training set variable in a given forest, aggregated across ensembles and trees \subsection{Usage}{ \if{html}{\out{
}}\preformatted{ForestSamples$get_aggregate_split_counts(num_features)}\if{html}{\out{
}} } @@ -451,7 +454,7 @@ Retrieve a vector of split counts for every training set variable in a given for \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ForestSamples-get_granular_split_counts}{}}} \subsection{Method \code{get_granular_split_counts()}}{ -Retrieve a vector of split counts for every training set variable in a given forest +Retrieve a vector of split counts for every training set variable in a given forest, reported separately for each ensemble and tree \subsection{Usage}{ \if{html}{\out{
}}\preformatted{ForestSamples$get_granular_split_counts(num_features)}\if{html}{\out{
}} } @@ -464,4 +467,59 @@ Retrieve a vector of split counts for every training set variable in a given for \if{html}{\out{}} } } +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-ensemble_tree_max_depth}{}}} +\subsection{Method \code{ensemble_tree_max_depth()}}{ +Maximum depth of a specific tree in a specific ensemble in a \code{ForestContainer} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$ensemble_tree_max_depth(ensemble_num, tree_num)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{ensemble_num}}{Ensemble number} + +\item{\code{tree_num}}{Tree index within ensemble \code{ensemble_num}} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Maximum leaf depth +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-average_ensemble_max_depth}{}}} +\subsection{Method \code{average_ensemble_max_depth()}}{ +Average the maximum depth of each tree in a given ensemble in a \code{ForestContainer} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$average_ensemble_max_depth(ensemble_num)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{ensemble_num}}{Ensemble number} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Average maximum depth +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestSamples-average_max_depth}{}}} +\subsection{Method \code{average_max_depth()}}{ +Average the maximum depth of each tree in each ensemble in a \code{ForestContainer} object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestSamples$average_max_depth()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Average maximum depth +} +} } diff --git a/man/bart_cpp_loop_specialized.Rd b/man/bart_cpp_loop_specialized.Rd index 03066106..4568afa1 100644 --- a/man/bart_cpp_loop_specialized.Rd +++ b/man/bart_cpp_loop_specialized.Rd @@ -13,7 +13,7 @@ bart_cpp_loop_specialized( alpha = 0.95, beta = 2, min_samples_leaf = 5, - max_depth = -1, + max_depth = 10, nu = 3, lambda = NULL, a_leaf = 3, diff --git a/src/cpp11.cpp b/src/cpp11.cpp index c769fd8b..234c7bcb 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -960,6 +960,9 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2}, {"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7}, {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, + {"_stochtree_average_max_depth_bart_generalized_cpp", (DL_FUNC) &_stochtree_average_max_depth_bart_generalized_cpp, 1}, + {"_stochtree_average_max_depth_bart_specialized_cpp", (DL_FUNC) &_stochtree_average_max_depth_bart_specialized_cpp, 1}, + {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1}, {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, {"_stochtree_create_rfx_dataset_cpp", (DL_FUNC) &_stochtree_create_rfx_dataset_cpp, 0}, @@ -968,6 +971,8 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_dataset_num_basis_cpp, 1}, {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, + {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) 
&_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, + {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 3}, {"_stochtree_forest_container_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_cpp, 2}, {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, @@ -1060,13 +1065,22 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, + {"_stochtree_run_bart_cpp_basis_notest_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_norfx, 29}, + {"_stochtree_run_bart_cpp_basis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_notest_rfx, 39}, + {"_stochtree_run_bart_cpp_basis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_norfx, 34}, + {"_stochtree_run_bart_cpp_basis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_basis_test_rfx, 48}, + {"_stochtree_run_bart_cpp_nobasis_notest_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_norfx, 27}, + {"_stochtree_run_bart_cpp_nobasis_notest_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_notest_rfx, 37}, + {"_stochtree_run_bart_cpp_nobasis_test_norfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_norfx, 30}, + {"_stochtree_run_bart_cpp_nobasis_test_rfx", (DL_FUNC) &_stochtree_run_bart_cpp_nobasis_test_rfx, 44}, + {"_stochtree_run_bart_specialized_cpp", (DL_FUNC) &_stochtree_run_bart_specialized_cpp, 24}, {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 13}, {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 13}, 
{"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 4}, {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 5}, {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, - {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 3}, + {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, {"_stochtree_update_residual_forest_container_cpp", (DL_FUNC) &_stochtree_update_residual_forest_container_cpp, 5}, {NULL, NULL, 0} }; diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index d506cbaf..fd8e3275 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -435,7 +435,7 @@ class ForestContainerCpp { class ForestSamplerCpp { public: - ForestSamplerCpp(ForestDatasetCpp& dataset, py::array_t feature_types, int num_trees, data_size_t num_obs, double alpha, double beta, int min_samples_leaf, int max_depth = -1) { + ForestSamplerCpp(ForestDatasetCpp& dataset, py::array_t feature_types, int num_trees, data_size_t num_obs, double alpha, double beta, int min_samples_leaf, int max_depth) { // Convert vector of integers to std::vector of enum FeatureType std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { diff --git a/stochtree/sampler.py b/stochtree/sampler.py index 5c013a6a..21977619 100644 --- a/stochtree/sampler.py +++ b/stochtree/sampler.py @@ -13,7 +13,7 @@ def __init__(self, random_seed: int) -> None: class ForestSampler: - def __init__(self, dataset: Dataset, feature_types: np.array, num_trees: int, num_obs: int, alpha: float, beta: float, min_samples_leaf: int, max_depth: int) -> None: + def __init__(self, dataset: Dataset, feature_types: np.array, num_trees: int, num_obs: int, alpha: float, beta: float, min_samples_leaf: 
int, max_depth: int = -1) -> None: # Initialize a ForestDatasetCpp object self.forest_sampler_cpp = ForestSamplerCpp(dataset.dataset_cpp, feature_types, num_trees, num_obs, alpha, beta, min_samples_leaf, max_depth) diff --git a/test/R/testthat/test-residual.R b/test/R/testthat/test-residual.R index f0f3e20d..4eb0fe3a 100644 --- a/test/R/testthat/test-residual.R +++ b/test/R/testthat/test-residual.R @@ -28,12 +28,13 @@ test_that("Residual updates correctly propagated after forest sampling step", { current_sigma2 = 1. current_leaf_scale = as.matrix(1./num_trees,nrow=1,ncol=1) cutpoint_grid_size = 100 + max_depth = 10 # RNG cpp_rng = createRNG(-1) # Create forest sampler and forest container - forest_model = createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf) + forest_model = createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) forest_samples = createForestContainer(num_trees, 1, F) # Initialize the leaves of each tree in the prognostic forest From be90ccc4573bad50a4437a16b897e11a79c9160d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 24 Jul 2024 01:50:46 -0400 Subject: [PATCH 17/18] Updated pybind initializer for ForestSamplerCpp --- src/py_stochtree.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index fd8e3275..3c8ca606 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -939,7 +939,7 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("GetGranularSplitCounts", &ForestContainerCpp::GetGranularSplitCounts); py::class_(m, "ForestSamplerCpp") - .def(py::init, int, data_size_t, double, double, int>()) + .def(py::init, int, data_size_t, double, double, int, int>()) .def("SampleOneIteration", &ForestSamplerCpp::SampleOneIteration); py::class_(m, "GlobalVarianceModelCpp") From ee05f2290987554fe0ed8e62fe4c6e143e6d6891 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 24 Jul 2024 15:44:14 -0400 
Subject: [PATCH 18/18] Updated R package docs --- _pkgdown.yml | 4 ++++ include/stochtree/container.h | 13 ------------- man/ForestModel.Rd | 2 -- src/cpp11.cpp | 27 --------------------------- 4 files changed, 4 insertions(+), 42 deletions(-) diff --git a/_pkgdown.yml b/_pkgdown.yml index bffe900a..dd92343f 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -70,6 +70,10 @@ reference: - createForestKernel - CppRNG - createRNG + - average_max_depth_bart_generalized + - average_max_depth_bart_specialized + - bart_cpp_loop_generalized + - bart_cpp_loop_specialized - subtitle: Random Effects desc: > diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 4191dce8..b3a7d806 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -55,19 +55,6 @@ class ForestContainer { } return numerator / denominator; } - inline int32_t EnsembleTreeMaxDepth(int ensemble_num, int tree_num) {return forests_[ensemble_num]->TreeMaxDepth(tree_num);} - inline double EnsembleAverageMaxDepth(int ensemble_num) {return forests_[ensemble_num]->AverageMaxDepth();} - inline double AverageMaxDepth() { - double numerator = 0.; - double denominator = 0.; - for (int i = 0; i < num_samples_; i++) { - for (int j = 0; j < num_trees_; j++) { - numerator += static_cast(forests_[i]->TreeMaxDepth(j)); - denominator += 1.; - } - } - return numerator / denominator; - } inline int32_t OutputDimension() {return output_dimension_;} inline int32_t OutputDimension(int ensemble_num) {return forests_[ensemble_num]->OutputDimension();} inline bool IsLeafConstant() {return is_leaf_constant_;} diff --git a/man/ForestModel.Rd b/man/ForestModel.Rd index 1898410f..82bd4337 100644 --- a/man/ForestModel.Rd +++ b/man/ForestModel.Rd @@ -39,8 +39,6 @@ Create a new ForestModel object. 
beta, min_samples_leaf, max_depth = -1 - min_samples_leaf, - max_depth = -1 )}\if{html}{\out{}} } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index c60ea3e8..234c7bcb 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -454,27 +454,6 @@ extern "C" SEXP _stochtree_average_max_depth_forest_container_cpp(SEXP forest_sa END_CPP11 } // forest.cpp -int ensemble_tree_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples, int ensemble_num, int tree_num); -extern "C" SEXP _stochtree_ensemble_tree_max_depth_forest_container_cpp(SEXP forest_samples, SEXP ensemble_num, SEXP tree_num) { - BEGIN_CPP11 - return cpp11::as_sexp(ensemble_tree_max_depth_forest_container_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>(ensemble_num), cpp11::as_cpp>(tree_num))); - END_CPP11 -} -// forest.cpp -double ensemble_average_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples, int ensemble_num); -extern "C" SEXP _stochtree_ensemble_average_max_depth_forest_container_cpp(SEXP forest_samples, SEXP ensemble_num) { - BEGIN_CPP11 - return cpp11::as_sexp(ensemble_average_max_depth_forest_container_cpp(cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>(ensemble_num))); - END_CPP11 -} -// forest.cpp -double average_max_depth_forest_container_cpp(cpp11::external_pointer forest_samples); -extern "C" SEXP _stochtree_average_max_depth_forest_container_cpp(SEXP forest_samples) { - BEGIN_CPP11 - return cpp11::as_sexp(average_max_depth_forest_container_cpp(cpp11::as_cpp>>(forest_samples))); - END_CPP11 -} -// forest.cpp int num_trees_forest_container_cpp(cpp11::external_pointer forest_samples); extern "C" SEXP _stochtree_num_trees_forest_container_cpp(SEXP forest_samples) { BEGIN_CPP11 @@ -736,12 +715,9 @@ extern "C" SEXP _stochtree_rng_cpp(SEXP random_seed) { } // sampler.cpp cpp11::external_pointer tree_prior_cpp(double alpha, double beta, int min_samples_leaf, int max_depth); -extern "C" SEXP _stochtree_tree_prior_cpp(SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP 
max_depth) { -cpp11::external_pointer tree_prior_cpp(double alpha, double beta, int min_samples_leaf, int max_depth); extern "C" SEXP _stochtree_tree_prior_cpp(SEXP alpha, SEXP beta, SEXP min_samples_leaf, SEXP max_depth) { BEGIN_CPP11 return cpp11::as_sexp(tree_prior_cpp(cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth))); - return cpp11::as_sexp(tree_prior_cpp(cpp11::as_cpp>(alpha), cpp11::as_cpp>(beta), cpp11::as_cpp>(min_samples_leaf), cpp11::as_cpp>(max_depth))); END_CPP11 } // sampler.cpp @@ -997,8 +973,6 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, - {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, - {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 3}, {"_stochtree_forest_container_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_cpp, 2}, {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, @@ -1107,7 +1081,6 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, - {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, 
{"_stochtree_update_residual_forest_container_cpp", (DL_FUNC) &_stochtree_update_residual_forest_container_cpp, 5}, {NULL, NULL, 0} };