Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ jobs:
with:
python-version: 3.11

- name: Install TensorFlow
- name: Install TensorFlow/Keras
run: |
reticulate::virtualenv_create('r-reticulate', python='3.11')
reticulate::use_virtualenv('r-reticulate')
tensorflow::install_tensorflow(version='2.16')
install.packages(c("keras3"))
keras3::install_keras()
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2
Expand Down
15 changes: 3 additions & 12 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,10 @@ jobs:
extra-packages: any::covr, any::xml2
needs: coverage

- name: Install dev reticulate
run: pak::pkg_install('rstudio/reticulate')
shell: Rscript {0}

- name: Install Miniconda
run: |
reticulate::install_miniconda()
shell: Rscript {0}

- name: Install TensorFlow
- name: Install TensorFlow/Keras
run: |
reticulate::conda_create('r-reticulate', packages = c('python==3.11'))
tensorflow::install_tensorflow(version='2.16')
install.packages(c("keras3"))
keras3::install_keras()
shell: Rscript {0}

- name: Test coverage
Expand Down
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.3.3.9001
Version: 1.4.0
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Max", "Kuhn", , "[email protected]", role = c("cre", "aut"),
comment = c(ORCID = "0000-0003-2402-136X")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
person("Emil", "Hvitfeldt", , "[email protected]", role = "ctb"),
person("Posit Software, PBC", role = c("cph", "fnd"),
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# parsnip (development version)
# parsnip 1.4.0

* Fixes issue with running predictions for Decision Trees in Spark (#1309)

Expand Down
Binary file modified data/model_db.rda
Binary file not shown.
79 changes: 59 additions & 20 deletions inst/add-in/parsnip_model_db.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,66 @@ library(usethis)
# also requires installation of:
packages <- c(
"parsnip",
"discrim",
"plsmod",
"rules",
"baguette",
"poissonreg",
"multilevelmod",
"modeltime",
"modeltime.gluonts"
parsnip:::extensions(),
"modeltime"
# "modeltime.gluonts" # required python packages to create spec
)

loaded <- map(packages, library, character.only = TRUE)

# ------------------------------------------------------------------------------

# Detects model specifications via their print methods
print_methods <- function(x) {
require(x, character.only = TRUE)
ns <- asNamespace(ns = x)
mthds <- ls(envir = ns, pattern = "^print\\.")
mthds <- gsub("^print\\.", "", mthds)
purrr::map(mthds, get_engines) |>
# Looks up `x` with get_from_env() and, when an entry is found, adds a `model`
# column holding `x` to it. A NULL lookup result is returned as NULL.
get_model <- function(x) {
  entry <- get_from_env(x)
  if (is.null(entry)) {
    return(NULL)
  }
  dplyr::mutate(entry, model = x)
}

# Looks up the "<x>_pkgs" entry with get_from_env(). When present, its `pkg`
# column is unnested and a `model` column holding `x` is added; when the
# lookup yields NULL, NULL is returned.
get_packages <- function(x) {
  pkg_tbl <- get_from_env(paste0(x, "_pkgs"))
  if (!is.null(pkg_tbl)) {
    pkg_tbl <- dplyr::mutate(tidyr::unnest(pkg_tbl, pkg), model = x)
  }
  pkg_tbl
}

# Collects the model definitions found in the model environment and attaches
# the package associated with each model/engine/mode combination.
#
# Model names are recovered from the `*_fit` objects in the environment
# returned by get_model_env(). Package rows come from get_packages() and are
# restricted to the script-level `packages` vector.
#
# Returns the joined table with three added/renamed columns:
#   package           - `pkg` from the package table, or "parsnip" when the
#                       join found no match (NA)
#   call_from_parsnip - TRUE when `package` is one of parsnip:::extensions()
#   caller_package    - "parsnip" when call_from_parsnip is TRUE, otherwise
#                       `package`
get_models <- function() {
  fit_entries <- ls(envir = get_model_env(), pattern = "_fit$")
  model_names <- gsub("_fit$", "", fit_entries)
  models <-
    purrr::map(model_names, get_model) |>
    purrr::list_rbind()

  # get source package
  # The filter must stay inside this pipeline: called on its own it has no
  # data argument and fails.
  pkgs <-
    unique(models$model) |>
    purrr::map(get_packages) |>
    purrr::list_rbind() |>
    dplyr::filter(pkg %in% packages)

  dplyr::left_join(models, pkgs, by = dplyr::join_by(engine, mode, model)) |>
    dplyr::rename(package = pkg) |>
    dplyr::mutate(
      package = dplyr::if_else(is.na(package), "parsnip", package),
      call_from_parsnip = package %in% parsnip:::extensions(),
      caller_package = dplyr::if_else(
        call_from_parsnip,
        "parsnip",
        package
      )
    )
}


get_engines <- function(x) {
eng <- try(parsnip::show_engines(x), silent = TRUE)
if (inherits(eng, "try-error")) {
Expand Down Expand Up @@ -77,8 +115,8 @@ get_tunable_param <- function(mode, package, model, engine) {
# ------------------------------------------------------------------------------

model_db <-
purrr::map(packages, print_methods) |>
purrr::list_rbind() |>
get_models() |>
dplyr::filter(mode %in% c("regression", "classification")) |>
dplyr::filter(engine != "liquidSVM") |>
dplyr::filter(model != "surv_reg") |>
dplyr::filter(engine != "spark") |>
Expand All @@ -98,9 +136,10 @@ model_db <-
dplyr::left_join(model_db, num_modes, by = c("package", "model", "engine")) |>
dplyr::mutate(
parameters = purrr::pmap(
list(mode, package, model, engine),
list(mode, caller_package, model, engine),
get_tunable_param
)
)
) |>
dplyr::select(-call_from_parsnip, -caller_package)

usethis::use_data(model_db, overwrite = TRUE)
2 changes: 2 additions & 0 deletions inst/models.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
"bart" "regression" "dbarts" NA
"boost_tree" "censored regression" "mboost" "censored"
"boost_tree" "classification" "C5.0" NA
"boost_tree" "classification" "catboost" "bonsai"
"boost_tree" "classification" "h2o" "agua"
"boost_tree" "classification" "h2o_gbm" "agua"
"boost_tree" "classification" "lightgbm" "bonsai"
"boost_tree" "classification" "spark" NA
"boost_tree" "classification" "xgboost" NA
"boost_tree" "regression" "catboost" "bonsai"
"boost_tree" "regression" "h2o" "agua"
"boost_tree" "regression" "h2o_gbm" "agua"
"boost_tree" "regression" "lightgbm" "bonsai"
Expand Down
7 changes: 5 additions & 2 deletions man/details_bart_dbarts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/details_boost_tree_h2o.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/details_boost_tree_xgboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions man/details_proportional_hazards_glmnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions man/details_rand_forest_aorsf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 14 additions & 10 deletions man/details_svm_rbf_kernlab.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions man/rmd/decision_tree_partykit.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@



For this engine, there are multiple modes: regression, classification, and censored regression
For this engine, there are multiple modes: censored regression, regression, and classification

## Tuning Parameters



This model has 2 tuning parameters:

- `min_n`: Minimal Node Size (type: integer, default: 20L)

- `tree_depth`: Tree Depth (type: integer, default: see below)

- `min_n`: Minimal Node Size (type: integer, default: 20L)

The `tree_depth` parameter defaults to `0` which means no restrictions are applied to tree depth.

An engine-specific parameter for this model is:
Expand Down Expand Up @@ -128,11 +128,11 @@ parsnip:::get_from_env("decision_tree_predict") |>
## # A tibble: 5 x 2
## mode type
## <chr> <chr>
## 1 regression numeric
## 2 classification class
## 3 classification prob
## 4 censored regression time
## 5 censored regression survival
## 1 censored regression time
## 2 censored regression survival
## 3 regression numeric
## 4 classification class
## 5 classification prob
```

## Other details
Expand Down
3 changes: 2 additions & 1 deletion man/rmd/rand_forest_aorsf.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ rand_forest() |>

parsnip:::get_from_env("rand_forest_predict") |>
dplyr::filter(engine == "aorsf") |>
dplyr::select(mode, type)
dplyr::select(mode, type)|>
print(n = Inf)

```

Expand Down
Loading