diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml
index 150ed5e7a..135e8834f 100644
--- a/.github/workflows/R-CMD-check.yaml
+++ b/.github/workflows/R-CMD-check.yaml
@@ -67,11 +67,10 @@ jobs:
with:
python-version: 3.11
- - name: Install TensorFlow
+ - name: Install TensorFlow/Keras
run: |
- reticulate::virtualenv_create('r-reticulate', python='3.11')
- reticulate::use_virtualenv('r-reticulate')
- tensorflow::install_tensorflow(version='2.16')
+          install.packages("keras3")
+ keras3::install_keras()
shell: Rscript {0}
- uses: r-lib/actions/check-r-package@v2
diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml
index 6e56ef9a4..b29fc83bf 100644
--- a/.github/workflows/test-coverage.yaml
+++ b/.github/workflows/test-coverage.yaml
@@ -27,19 +27,10 @@ jobs:
extra-packages: any::covr, any::xml2
needs: coverage
- - name: Install dev reticulate
- run: pak::pkg_install('rstudio/reticulate')
- shell: Rscript {0}
-
- - name: Install Miniconda
- run: |
- reticulate::install_miniconda()
- shell: Rscript {0}
-
- - name: Install TensorFlow
+ - name: Install TensorFlow/Keras
run: |
- reticulate::conda_create('r-reticulate', packages = c('python==3.11'))
- tensorflow::install_tensorflow(version='2.16')
+          install.packages("keras3")
+ keras3::install_keras()
shell: Rscript {0}
- name: Test coverage
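For anyone reproducing the new CI setup locally, here is a minimal sketch (not part of the patch) of what the replacement step boils down to. keras3::install_keras() provisions a managed Python environment (via reticulate) and installs the TensorFlow backend into it, which is presumably why the explicit virtualenv/Miniconda and tensorflow::install_tensorflow() steps above could be dropped.

```r
# Run once in a fresh R session to mirror the new workflow step:
# keras3 bootstraps its own Python environment and installs
# TensorFlow + Keras into it.
install.packages("keras3")
keras3::install_keras()
```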
diff --git a/DESCRIPTION b/DESCRIPTION
index 74f62fa96..0e2119f76 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,8 +1,9 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
-Version: 1.3.3.9001
+Version: 1.4.0
Authors@R: c(
- person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
+ person("Max", "Kuhn", , "max@posit.co", role = c("cre", "aut"),
+ comment = c(ORCID = "0000-0003-2402-136X")),
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),
person("Emil", "Hvitfeldt", , "emil.hvitfeldt@posit.co", role = "ctb"),
person("Posit Software, PBC", role = c("cph", "fnd"),
diff --git a/NEWS.md b/NEWS.md
index d52d6950b..b88d5eee0 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,4 +1,4 @@
-# parsnip (development version)
+# parsnip 1.4.0
* Fixes issue with running predictions for Decision Trees in Spark (#1309)
diff --git a/data/model_db.rda b/data/model_db.rda
index 8168b7e22..784b7dc06 100644
Binary files a/data/model_db.rda and b/data/model_db.rda differ
diff --git a/inst/add-in/parsnip_model_db.R b/inst/add-in/parsnip_model_db.R
index b3cf38b0b..c5bec0d66 100644
--- a/inst/add-in/parsnip_model_db.R
+++ b/inst/add-in/parsnip_model_db.R
@@ -9,28 +9,66 @@ library(usethis)
# also requires installation of:
packages <- c(
"parsnip",
- "discrim",
- "plsmod",
- "rules",
- "baguette",
- "poissonreg",
- "multilevelmod",
- "modeltime",
- "modeltime.gluonts"
+ parsnip:::extensions(),
+ "modeltime"
+ # "modeltime.gluonts" # required python packages to create spec
)
+loaded <- purrr::map(packages, library, character.only = TRUE)
+
# ------------------------------------------------------------------------------
-# Detects model specifications via their print methods
-print_methods <- function(x) {
- require(x, character.only = TRUE)
- ns <- asNamespace(ns = x)
- mthds <- ls(envir = ns, pattern = "^print\\.")
- mthds <- gsub("^print\\.", "", mthds)
- purrr::map(mthds, get_engines) |>
+get_model <- function(x) {
+ res <- get_from_env(x)
+ if (!is.null(res)) {
+ res <- dplyr::mutate(res, model = x)
+ }
+ res
+}
+
+get_packages <- function(x) {
+ res <- get_from_env(paste0(x, "_pkgs"))
+ if (is.null(res)) {
+ return(res)
+ }
+ res <-
+ res |>
+ tidyr::unnest(pkg) |>
+ dplyr::mutate(
+ model = x
+ )
+
+ res
+}
+
+get_models <- function() {
+ res <- ls(envir = get_model_env(), pattern = "_fit$")
+ models <- gsub("_fit$", "", res)
+ models <-
+ purrr::map(models, get_model) |>
+ purrr::list_rbind()
+
+ # get source package
+ pkgs <- gsub("_fit$", "_pkgs", res)
+ pkgs <-
+ unique(models$model) |>
+ purrr::map(get_packages) |>
purrr::list_rbind() |>
- dplyr::mutate(package = x)
+ dplyr::filter(pkg %in% packages)
+ dplyr::left_join(models, pkgs, by = dplyr::join_by(engine, mode, model)) |>
+ dplyr::rename(package = pkg) |>
+ dplyr::mutate(
+ package = dplyr::if_else(is.na(package), "parsnip", package),
+ call_from_parsnip = package %in% parsnip:::extensions(),
+ caller_package = dplyr::if_else(
+ call_from_parsnip,
+ "parsnip",
+ package
+ )
+ )
}
+
+
get_engines <- function(x) {
eng <- try(parsnip::show_engines(x), silent = TRUE)
if (inherits(eng, "try-error")) {
@@ -77,8 +115,8 @@ get_tunable_param <- function(mode, package, model, engine) {
# ------------------------------------------------------------------------------
model_db <-
- purrr::map(packages, print_methods) |>
- purrr::list_rbind() |>
+ get_models() |>
+ dplyr::filter(mode %in% c("regression", "classification")) |>
dplyr::filter(engine != "liquidSVM") |>
dplyr::filter(model != "surv_reg") |>
dplyr::filter(engine != "spark") |>
@@ -98,9 +136,10 @@ model_db <-
dplyr::left_join(model_db, num_modes, by = c("package", "model", "engine")) |>
dplyr::mutate(
parameters = purrr::pmap(
- list(mode, package, model, engine),
+ list(mode, caller_package, model, engine),
get_tunable_param
)
- )
+ ) |>
+ dplyr::select(-call_from_parsnip, -caller_package)
usethis::use_data(model_db, overwrite = TRUE)
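As background on the registry calls that the rewritten add-in script relies on, a minimal sketch (not part of the patch) of querying parsnip's model environment interactively; get_from_env() and get_model_env() are the same developer-facing accessors used in get_model(), get_packages(), and get_models() above.

```r
library(parsnip)

# Each registered model has a "<model>" entry (engine/mode pairs) and a
# "<model>_pkgs" entry (package dependencies) in the model environment.
get_from_env("rand_forest")       # tibble with engine and mode columns
get_from_env("rand_forest_pkgs")  # engine, pkg (list-column), and mode

# The "*_fit" objects enumerated by get_models() live in the same place:
head(ls(envir = get_model_env(), pattern = "_fit$"))
```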
diff --git a/inst/models.tsv b/inst/models.tsv
index 2c39fff0c..a1ab034e7 100644
--- a/inst/models.tsv
+++ b/inst/models.tsv
@@ -14,11 +14,13 @@
"bart" "regression" "dbarts" NA
"boost_tree" "censored regression" "mboost" "censored"
"boost_tree" "classification" "C5.0" NA
+"boost_tree" "classification" "catboost" "bonsai"
"boost_tree" "classification" "h2o" "agua"
"boost_tree" "classification" "h2o_gbm" "agua"
"boost_tree" "classification" "lightgbm" "bonsai"
"boost_tree" "classification" "spark" NA
"boost_tree" "classification" "xgboost" NA
+"boost_tree" "regression" "catboost" "bonsai"
"boost_tree" "regression" "h2o" "agua"
"boost_tree" "regression" "h2o_gbm" "agua"
"boost_tree" "regression" "lightgbm" "bonsai"
diff --git a/man/details_bart_dbarts.Rd b/man/details_bart_dbarts.Rd
index 352bc0467..49a2c5685 100644
--- a/man/details_bart_dbarts.Rd
+++ b/man/details_bart_dbarts.Rd
@@ -124,7 +124,8 @@ indicators if the user does not create them first.
\subsection{Prediction types}{
\if{html}{\out{
}}\preformatted{parsnip:::get_from_env("bart_predict") |>
- dplyr::select(mode, type)
+ dplyr::select(mode, type) |>
+ print(n = Inf)
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{## # A tibble: 9 x 2
@@ -136,7 +137,9 @@ indicators if the user does not create them first.
## 4 regression pred_int
## 5 classification class
## 6 classification prob
-## # i 3 more rows
+## 7 classification conf_int
+## 8 classification pred_int
+## 9 classification raw
}\if{html}{\out{
}}
}
diff --git a/man/details_boost_tree_h2o.Rd b/man/details_boost_tree_h2o.Rd
index 94ad47e74..801067b33 100644
--- a/man/details_boost_tree_h2o.Rd
+++ b/man/details_boost_tree_h2o.Rd
@@ -160,7 +160,8 @@ within \verb{[0, 1]}.
\if{html}{\out{}}\preformatted{parsnip:::get_from_env("boost_tree_predict") |>
dplyr::filter(stringr::str_starts(engine, "h2o")) |>
- dplyr::select(mode, type)
+ dplyr::select(mode, type) |>
+ print(n = Inf)
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{## # A tibble: 8 x 2
@@ -172,7 +173,8 @@ within \verb{[0, 1]}.
## 4 classification prob
## 5 regression numeric
## 6 regression raw
-## # i 2 more rows
+## 7 classification class
+## 8 classification prob
}\if{html}{\out{
}}
}
diff --git a/man/details_boost_tree_xgboost.Rd b/man/details_boost_tree_xgboost.Rd
index dca1b8843..f68aa83f4 100644
--- a/man/details_boost_tree_xgboost.Rd
+++ b/man/details_boost_tree_xgboost.Rd
@@ -9,7 +9,10 @@ ensemble. Each tree depends on the results of previous trees. All trees in
the ensemble are combined to produce a final prediction.
}
\details{
-For this engine, there are multiple modes: classification and regression
+For this engine, there are multiple modes: classification and
+regression. Note that in late 2025, a new version of xgboost was
+released with differences in its interface and model objects. This
+version of parsnip should work with either release.
\subsection{Tuning Parameters}{
This model has 8 tuning parameters:
diff --git a/man/details_proportional_hazards_glmnet.Rd b/man/details_proportional_hazards_glmnet.Rd
index 3322a7c0c..7dcd2ff1a 100644
--- a/man/details_proportional_hazards_glmnet.Rd
+++ b/man/details_proportional_hazards_glmnet.Rd
@@ -154,6 +154,23 @@ predictions. When saving the model for the purpose of prediction, the
size of the saved object might be substantially reduced by using
functions from the \href{https://butcher.tidymodels.org}{butcher} package.
}
+
+\subsection{Prediction types}{
+
+\if{html}{\out{}}\preformatted{parsnip:::get_from_env("proportional_hazards_predict") |>
+ dplyr::filter(engine == "glmnet") |>
+ dplyr::select(mode, type)
+}\if{html}{\out{
}}
+
+\if{html}{\out{}}\preformatted{## # A tibble: 4 x 2
+## mode type
+## <chr> <chr>
+## 1 censored regression linear_pred
+## 2 censored regression survival
+## 3 censored regression time
+## 4 censored regression raw
+}\if{html}{\out{
}}
+}
}
\section{References}{
\itemize{
diff --git a/man/details_rand_forest_aorsf.Rd b/man/details_rand_forest_aorsf.Rd
index 6ccb4d7d6..8b2124cdc 100644
--- a/man/details_rand_forest_aorsf.Rd
+++ b/man/details_rand_forest_aorsf.Rd
@@ -119,7 +119,8 @@ The \code{fit()} and \code{fit_xy()} arguments have arguments called
\if{html}{\out{}}\preformatted{parsnip:::get_from_env("rand_forest_predict") |>
dplyr::filter(engine == "aorsf") |>
- dplyr::select(mode, type)
+  dplyr::select(mode, type) |>
+ print(n = Inf)
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{## # A tibble: 7 x 2
@@ -131,7 +132,7 @@ The \code{fit()} and \code{fit_xy()} arguments have arguments called
## 4 classification prob
## 5 classification raw
## 6 regression numeric
-## # i 1 more row
+## 7 regression raw
}\if{html}{\out{
}}
}
diff --git a/man/details_svm_rbf_kernlab.Rd b/man/details_svm_rbf_kernlab.Rd
index 897538684..5b9946557 100644
--- a/man/details_svm_rbf_kernlab.Rd
+++ b/man/details_svm_rbf_kernlab.Rd
@@ -111,19 +111,23 @@ The underlying model implementation does not allow for case weights.
\subsection{Prediction types}{
\if{html}{\out{}}\preformatted{parsnip:::get_from_env("svm_rbf_predict") |>
- dplyr::select(mode, type)
+ dplyr::select(mode, type) |>
+ print(n = Inf)
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{## # A tibble: 10 x 2
-## mode type
-## <chr> <chr>
-## 1 regression numeric
-## 2 regression raw
-## 3 classification class
-## 4 classification prob
-## 5 classification raw
-## 6 regression numeric
-## # i 4 more rows
+## mode type
+## <chr> <chr>
+## 1 regression numeric
+## 2 regression raw
+## 3 classification class
+## 4 classification prob
+## 5 classification raw
+## 6 regression numeric
+## 7 regression raw
+## 8 classification class
+## 9 classification prob
+## 10 classification raw
}\if{html}{\out{
}}
}
diff --git a/man/rmd/decision_tree_partykit.md b/man/rmd/decision_tree_partykit.md
index 9d712e4e2..2cce9c12d 100644
--- a/man/rmd/decision_tree_partykit.md
+++ b/man/rmd/decision_tree_partykit.md
@@ -1,7 +1,7 @@
-For this engine, there are multiple modes: regression, classification, and censored regression
+For this engine, there are multiple modes: censored regression, regression, and classification
## Tuning Parameters
@@ -9,10 +9,10 @@ For this engine, there are multiple modes: regression, classification, and censo
This model has 2 tuning parameters:
-- `min_n`: Minimal Node Size (type: integer, default: 20L)
-
- `tree_depth`: Tree Depth (type: integer, default: see below)
+- `min_n`: Minimal Node Size (type: integer, default: 20L)
+
The `tree_depth` parameter defaults to `0` which means no restrictions are applied to tree depth.
An engine-specific parameter for this model is:
@@ -128,11 +128,11 @@ parsnip:::get_from_env("decision_tree_predict") |>
## # A tibble: 5 x 2
## mode type
## <chr> <chr>
-## 1 regression numeric
-## 2 classification class
-## 3 classification prob
-## 4 censored regression time
-## 5 censored regression survival
+## 1 censored regression time
+## 2 censored regression survival
+## 3 regression numeric
+## 4 classification class
+## 5 classification prob
```
## Other details
diff --git a/man/rmd/rand_forest_aorsf.Rmd b/man/rmd/rand_forest_aorsf.Rmd
index 349a31f68..42f5a0276 100644
--- a/man/rmd/rand_forest_aorsf.Rmd
+++ b/man/rmd/rand_forest_aorsf.Rmd
@@ -96,7 +96,8 @@ rand_forest() |>
parsnip:::get_from_env("rand_forest_predict") |>
dplyr::filter(engine == "aorsf") |>
- dplyr::select(mode, type)
+  dplyr::select(mode, type) |>
+ print(n = Inf)
```
diff --git a/man/rmd/rand_forest_aorsf.md b/man/rmd/rand_forest_aorsf.md
index 2c8a53b4f..7773a2d55 100644
--- a/man/rmd/rand_forest_aorsf.md
+++ b/man/rmd/rand_forest_aorsf.md
@@ -1,7 +1,7 @@
-For this engine, there are multiple modes: classification, regression, and censored regression
+For this engine, there are multiple modes: censored regression, classification, and regression
## Tuning Parameters
@@ -9,12 +9,12 @@ For this engine, there are multiple modes: classification, regression, and censo
This model has 3 tuning parameters:
-- `mtry`: # Randomly Selected Predictors (type: integer, default: ceiling(sqrt(n_predictors)))
-
- `trees`: # Trees (type: integer, default: 500L)
- `min_n`: Minimal Node Size (type: integer, default: 5L)
+- `mtry`: # Randomly Selected Predictors (type: integer, default: ceiling(sqrt(n_predictors)))
+
Additionally, this model has one engine-specific tuning parameter:
* `split_min_stat`: Minimum test statistic required to split a node. Defaults are `3.841459` for censored regression (which is roughly a p-value of 0.05) and `0` for classification and regression. For classification, this tuning parameter should be between 0 and 1, and for regression it should be greater than or equal to 0. Higher values of this parameter cause trees grown by `aorsf` to have less depth.
@@ -108,20 +108,21 @@ The `fit()` and `fit_xy()` arguments have arguments called `case_weights` that e
``` r
parsnip:::get_from_env("rand_forest_predict") |>
dplyr::filter(engine == "aorsf") |>
- dplyr::select(mode, type)
+  dplyr::select(mode, type) |>
+ print(n = Inf)
```
```
## # A tibble: 7 x 2
-## mode type
-## <chr> <chr>
-## 1 classification class
-## 2 classification prob
-## 3 classification raw
-## 4 regression numeric
-## 5 regression raw
-## 6 censored regression time
-## # i 1 more row
+## mode type
+## <chr> <chr>
+## 1 censored regression time
+## 2 censored regression survival
+## 3 classification class
+## 4 classification prob
+## 5 classification raw
+## 6 regression numeric
+## 7 regression raw
```
## Other details
diff --git a/man/rmd/rand_forest_partykit.md b/man/rmd/rand_forest_partykit.md
index a5e76e1b8..35b9db9be 100644
--- a/man/rmd/rand_forest_partykit.md
+++ b/man/rmd/rand_forest_partykit.md
@@ -1,7 +1,7 @@
-For this engine, there are multiple modes: regression, classification, and censored regression
+For this engine, there are multiple modes: censored regression, regression, and classification
## Tuning Parameters
@@ -9,12 +9,12 @@ For this engine, there are multiple modes: regression, classification, and censo
This model has 3 tuning parameters:
+- `trees`: # Trees (type: integer, default: 500L)
+
- `min_n`: Minimal Node Size (type: integer, default: 20L)
- `mtry`: # Randomly Selected Predictors (type: integer, default: 5L)
-- `trees`: # Trees (type: integer, default: 500L)
-
## Translation from parsnip to the original package (regression)
The **bonsai** extension package is required to fit this model.
@@ -110,11 +110,11 @@ parsnip:::get_from_env("rand_forest_predict") |>
## # A tibble: 5 x 2
## mode type
## <chr> <chr>
-## 1 regression numeric
-## 2 classification class
-## 3 classification prob
-## 4 censored regression time
-## 5 censored regression survival
+## 1 censored regression time
+## 2 censored regression survival
+## 3 regression numeric
+## 4 classification class
+## 5 classification prob
```
## Other details
diff --git a/tests/testthat/test-fit_interfaces.R b/tests/testthat/test-fit_interfaces.R
index 440ab9a12..31b084582 100644
--- a/tests/testthat/test-fit_interfaces.R
+++ b/tests/testthat/test-fit_interfaces.R
@@ -172,7 +172,7 @@ test_that("overhead of parsnip interface is minimal (#1071)", {
skip_on_cran()
skip_on_covr()
skip_if_not_installed("bench")
- skip_if_not_installed("parsnip", minimum_version = "1.4.0")
+ skip_if_not_installed("parsnip", minimum_version = "1.5.0")
bm <- bench::mark(
time_engine = lm(mpg ~ ., mtcars),
diff --git a/vignettes/parsnip.Rmd b/vignettes/parsnip.Rmd
index a2474bceb..3718eec1a 100644
--- a/vignettes/parsnip.Rmd
+++ b/vignettes/parsnip.Rmd
@@ -159,7 +159,7 @@ rf_with_seed |>
Note that the call objects show `num.trees = ~2000`. The tilde is the consequence of `parsnip` using [quosures](https://adv-r.hadley.nz/evaluation.html#quosures) to process the model specification's arguments.
-Normally, when a function is executed, the function's arguments are immediately evaluated. In the case of parsnip, the model specification's arguments are _not_; the [expression is captured](https://www.tidyverse.org/blog/2019/04/parsnip-internals/) along with the environment where it should be evaluated. That is what a quosure does.
+Normally, when a function is executed, the function's arguments are immediately evaluated. In the case of parsnip, the model specification's arguments are _not_; the [expression is captured](https://tidyverse.org/blog/2019/04/parsnip-internals/) along with the environment where it should be evaluated. That is what a quosure does.
parsnip uses these expressions to make a model fit call that is evaluated. The tilde in the call above reflects that the argument was captured using a quosure.
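To make the quosure behavior described above concrete, here is a small sketch (not part of the patch; it assumes parsnip and, for fitting, ranger are installed) showing where the captured expression ends up:

```r
library(parsnip)

# The value 2000 is not evaluated when the specification is created; it is
# captured as a quosure together with the environment it came from.
rf_spec <- rand_forest(trees = 2000) |>
  set_engine("ranger") |>
  set_mode("regression")

rf_spec$args$trees  # a quosure wrapping the expression 2000

# translate() shows the eventual ranger::ranger() call; the captured
# argument appears there as `num.trees = ~2000`, the tilde discussed above.
translate(rf_spec)
```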