Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ Suggests:
testthat (>= 3.0.0),
torch,
TH.data,
tune,
usethis (>= 1.5.0),
workflowsets,
xgboost (>= 1.3.2.1),
xrf
VignetteBuilder:
Expand Down
16 changes: 15 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ S3method(axe_env,terms)
S3method(axe_env,train)
S3method(axe_env,train.recipe)
S3method(axe_env,xgb.Booster)
S3method(axe_fitted,"_tabnet_fit")
S3method(axe_env,xrf)
S3method(axe_fitted,"_tabnet_fit")
S3method(axe_fitted,C5.0)
S3method(axe_fitted,KMeansCluster)
S3method(axe_fitted,bart)
Expand All @@ -139,6 +139,18 @@ S3method(axe_fitted,ranger)
S3method(axe_fitted,recipe)
S3method(axe_fitted,train)
S3method(axe_fitted,train.recipe)
S3method(axe_rsample_data,default)
S3method(axe_rsample_data,rset)
S3method(axe_rsample_data,rsplit)
S3method(axe_rsample_data,three_way_split)
S3method(axe_rsample_data,tune_results)
S3method(axe_rsample_data,workflow_set)
S3method(axe_rsample_indicators,default)
S3method(axe_rsample_indicators,rset)
S3method(axe_rsample_indicators,rsplit)
S3method(axe_rsample_indicators,three_way_split)
S3method(axe_rsample_indicators,tune_results)
S3method(axe_rsample_indicators,workflow_set)
S3method(weigh,default)
S3method(weigh,ksvm)
S3method(weigh,model_fit)
Expand All @@ -147,6 +159,8 @@ export(axe_ctrl)
export(axe_data)
export(axe_env)
export(axe_fitted)
export(axe_rsample_data)
export(axe_rsample_indicators)
export(butcher)
export(locate)
export(new_model_butcher)
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

* Added methods for `MASS::polr` (@pbulsink, #289).

* Added methods for `rset`, `rsplit`, `tune_results`, and `workflow_set` classes (#292).

* Updated to work with new versions of xgboost models (#294).

* Added butcher methods for `tabnet()` (@cregouby #226).

* * Added methods for `MASS::polr` (@pbulsink, #289).

# butcher 0.3.6

Expand Down
82 changes: 81 additions & 1 deletion R/axe-generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ butcher <- function(x, verbose = FALSE, ...) {
x <- axe_data(x, verbose = FALSE, ...)
x <- axe_env(x, verbose = FALSE, ...)
x <- axe_fitted(x, verbose = FALSE, ...)
x <- axe_rsample_data(x, verbose = FALSE, ...)
x <- axe_rsample_indicators(x, verbose = FALSE, ...)

add_butcher_attributes(
x,
Expand Down Expand Up @@ -106,7 +108,6 @@ axe_data.default <- function(x, verbose = FALSE, ...) {
x
}


#' Axe an environment.
#'
#' Remove the environment(s) attached to modeling objects as they are
Expand Down Expand Up @@ -159,3 +160,82 @@ axe_fitted.default <- function(x, verbose = FALSE, ...) {
x
}

#' Axe data within rsample objects.
#'
#' Swap the data held by splitting and resampling objects for a placeholder.
#'
#' Splitting and resampling objects created by \pkg{rsample} hold `rsplit`
#' objects, and each of those carries the original data set. Because that data
#' can be large, it is sometimes worth dropping before an object is saved. This
#' method replaces the original data with a zero-row slice of it, so that only
#' the column names and their attributes are kept.
#'
#' @name axe-rsample-data
#' @inheritParams butcher
#' @param x An object.
#'
#' @return An updated object without data in the `rsplit` objects.
#'
#' @section Methods:
#' \Sexpr[stage=render,results=rd]{butcher:::methods_rd("axe_rsample_data")}
#'
#' @examplesIf rlang::is_installed("rsample")
#'
#' large_cars <- mtcars[rep(1:32, 50), ]
#' large_cars_split <- rsample::initial_split(large_cars)
#' butcher(large_cars_split, verbose = TRUE)
#'
#' @export
axe_rsample_data <- function(x, verbose = FALSE, ...) {
  # S3 generic: the method is chosen from the class of `x`.
  UseMethod("axe_rsample_data")
}

#' @export
#' @name axe-rsample-data
axe_rsample_data.default <- function(x, verbose = FALSE, ...) {
  # Fallback method: the object is returned unchanged, so the "before" and
  # "after" values handed to assess_object() are the same object.
  original <- x
  if (verbose) {
    assess_object(original, x)
  }
  x
}

#' Axe indicators within rsample objects.
#'
#' Swap the indicators held by splitting and resampling objects for a
#' placeholder.
#'
#' Splitting and resampling objects created by \pkg{rsample} hold `rsplit`
#' objects. Besides the original data set, these store indicators that specify
#' which rows go into which data partitions. The size of these integer vectors
#' might be large, so we sometimes wish to remove them when saving objects.
#' This method stores a zero-length integer vector in their place.
#'
#' @name axe-rsample-indicators
#' @inheritParams butcher
#' @param x An object.
#'
#' @return An updated object without the indicators in the `rsplit` objects.
#'
#' @section Methods:
#' \Sexpr[stage=render,results=rd]{butcher:::methods_rd("axe_rsample_indicators")}
#'
#' @examplesIf rlang::is_installed("rsample")
#'
#' large_cars <- mtcars[rep(1:32, 50), ]
#' large_cars_split <- rsample::initial_split(large_cars)
#' butcher(large_cars_split, verbose = TRUE)
#'
#' @export
axe_rsample_indicators <- function(x, verbose = FALSE, ...) {
  # S3 generic: the method is chosen from the class of `x`.
  UseMethod("axe_rsample_indicators")
}

#' @export
#' @name axe-rsample-indicators
axe_rsample_indicators.default <- function(x, verbose = FALSE, ...) {
  # Fallback method: the object is returned unchanged, so the "before" and
  # "after" values handed to assess_object() are the same object.
  original <- x
  if (verbose) {
    assess_object(original, x)
  }
  x
}
104 changes: 104 additions & 0 deletions R/rsample-data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#' @rdname axe-rsample-data
#' @export
axe_rsample_data.rsplit <- function(x, verbose = FALSE, ...) {
  # Function names handed to add_butcher_attributes() as `disabled` for the
  # butchered split.
  disabled_fns <- c(
    "analysis()",
    "as.data.frame()",
    "as.integer()",
    "assessment()",
    "complement()",
    "internal_calibration_split()",
    "populate()",
    "reverse_splits()",
    "testing()",
    "tidy()",
    "training()"
  )

  # `x` itself is kept as the "before" object; the zero-row copy is the
  # "after" object.
  butchered <- zero_data(x)
  add_butcher_attributes(
    butchered,
    x,
    disabled = disabled_fns,
    verbose = verbose
  )
}

#' @rdname axe-rsample-data
#' @export
axe_rsample_data.three_way_split <- function(x, verbose = FALSE, ...) {
  # Function names handed to add_butcher_attributes() as `disabled` for the
  # butchered three-way split.
  disabled_fns <- c(
    "internal_calibration_split()",
    "testing()",
    "training()",
    "validation()"
  )

  # `x` itself is kept as the "before" object; the zero-row copy is the
  # "after" object.
  butchered <- zero_data(x)
  add_butcher_attributes(
    butchered,
    x,
    disabled = disabled_fns,
    verbose = verbose
  )
}

#' @rdname axe-rsample-data
#' @export
axe_rsample_data.rset <- function(x, verbose = FALSE, ...) {
  butchered <- x

  # Only objects that have an element named `splits` are modified; each split
  # is axed individually (with that call's default `verbose = FALSE`).
  has_splits <- any(names(butchered) == "splits")
  if (has_splits) {
    butchered$splits <- purrr::map(butchered$splits, axe_rsample_data)
  }

  add_butcher_attributes(
    butchered,
    x,
    disabled = c("populate()", "reverse_splits()", "tidy()"),
    verbose = verbose
  )
}

#' @rdname axe-rsample-data
#' @export
axe_rsample_data.tune_results <- function(x, verbose = FALSE, ...) {
  butchered <- x

  # Only objects that have an element named `splits` are modified; each split
  # is axed individually (with that call's default `verbose = FALSE`).
  has_splits <- any(names(butchered) == "splits")
  if (has_splits) {
    butchered$splits <- purrr::map(butchered$splits, axe_rsample_data)
  }

  add_butcher_attributes(
    butchered,
    x,
    disabled = c("augment()", "fit_best()"),
    verbose = verbose
  )
}

#' @rdname axe-rsample-data
#' @export
axe_rsample_data.workflow_set <- function(x, verbose = FALSE, ...) {
  # Positions in the `result` column that hold `tune_results` objects.
  tuned <- which(
    purrr::map_lgl(x$result, ~ inherits(.x, "tune_results"))
  )

  # No element has tuning results, so the object is handed back untouched.
  if (length(tuned) == 0L) {
    return(x)
  }

  butchered <- x
  for (idx in tuned) {
    butchered$result[[idx]] <- axe_rsample_data(butchered$result[[idx]])
  }

  add_butcher_attributes(
    butchered,
    x,
    disabled = c("augment()", "fit_best()"),
    verbose = verbose
  )
}

# ------------------------------------------------------------------------------

# Replace the `data` element of a split with a zero-row slice of itself.
# `drop = FALSE` keeps the result two-dimensional; all other elements of
# `split` are returned unchanged.
zero_data <- function(split) {
  empty_slice <- split$data[0, , drop = FALSE]
  split$data <- empty_slice
  split
}
115 changes: 115 additions & 0 deletions R/rsample-indicators.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#' @rdname axe-rsample-indicators
#' @export
axe_rsample_indicators.rsplit <- function(x, verbose = FALSE, ...) {
  # Function names handed to add_butcher_attributes() as `disabled` for the
  # butchered split.
  disabled_fns <- c(
    "analysis()",
    "as.data.frame()",
    "as.integer()",
    "assessment()",
    "complement()",
    "internal_calibration_split()",
    "populate()",
    "reverse_splits()",
    "testing()",
    "tidy()",
    "training()"
  )

  # `x` itself is kept as the "before" object; the copy with emptied
  # indicators is the "after" object.
  butchered <- zero_all_ind(x)
  add_butcher_attributes(
    butchered,
    x,
    disabled = disabled_fns,
    verbose = verbose
  )
}

#' @rdname axe-rsample-indicators
#' @export
axe_rsample_indicators.three_way_split <- function(x, verbose = FALSE, ...) {
  # Function names handed to add_butcher_attributes() as `disabled` for the
  # butchered three-way split.
  disabled_fns <- c(
    "internal_calibration_split()",
    "testing()",
    "training()",
    "validation()"
  )

  # `x` itself is kept as the "before" object; the copy with emptied
  # indicators is the "after" object.
  butchered <- zero_all_ind(x)
  add_butcher_attributes(
    butchered,
    x,
    disabled = disabled_fns,
    verbose = verbose
  )
}

#' @rdname axe-rsample-indicators
#' @export
axe_rsample_indicators.rset <- function(x, verbose = FALSE, ...) {
  butchered <- x

  # Only objects that have an element named `splits` are modified; each split
  # is axed individually (with that call's default `verbose = FALSE`).
  has_splits <- any(names(butchered) == "splits")
  if (has_splits) {
    butchered$splits <- purrr::map(butchered$splits, axe_rsample_indicators)
  }

  add_butcher_attributes(
    butchered,
    x,
    disabled = c("populate()", "reverse_splits()", "tidy()"),
    verbose = verbose
  )
}


#' @rdname axe-rsample-indicators
#' @export
axe_rsample_indicators.tune_results <- function(x, verbose = FALSE, ...) {
  butchered <- x

  # Only objects that have an element named `splits` are modified; each split
  # is axed individually (with that call's default `verbose = FALSE`).
  has_splits <- any(names(butchered) == "splits")
  if (has_splits) {
    butchered$splits <- purrr::map(butchered$splits, axe_rsample_indicators)
  }

  add_butcher_attributes(
    butchered,
    x,
    disabled = c("augment()", "fit_best()"),
    verbose = verbose
  )
}


#' @rdname axe-rsample-indicators
#' @export
axe_rsample_indicators.workflow_set <- function(x, verbose = FALSE, ...) {
  # Positions in the `result` column that hold `tune_results` objects.
  tuned <- which(
    purrr::map_lgl(x$result, ~ inherits(.x, "tune_results"))
  )

  # No element has tuning results, so the object is handed back untouched.
  if (length(tuned) == 0L) {
    return(x)
  }

  butchered <- x
  for (idx in tuned) {
    butchered$result[[idx]] <- axe_rsample_indicators(butchered$result[[idx]])
  }

  add_butcher_attributes(
    butchered,
    x,
    disabled = c("augment()", "fit_best()"),
    verbose = verbose
  )
}

# ------------------------------------------------------------------------------

# Swap an indicator vector for a zero-length integer vector. An indicator
# whose values are all `NA` (which includes an empty one) is returned as is.
zero_ind <- function(ind) {
  if (all(is.na(ind))) {
    return(ind)
  }
  integer(0)
}

# Run zero_ind() over every element of `split` whose name ends in "_id";
# elements with other names are left alone.
zero_all_ind <- function(split) {
  all_names <- names(split)
  id_names <- all_names[grepl("_id$", all_names)]
  for (id_name in id_names) {
    split[[id_name]] <- zero_ind(split[[id_name]])
  }
  split
}
Binary file added inst/extdata/workflow_sets.RData
Binary file not shown.
Loading