Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# butcher (development version)

* Added methods for `MASS::polr` (@pbulsink, #289).

* Make to work with new versions of xgboost models (#294).

* Added butcher methods for `tabnet()` (@cregouby #226).

* * Added methods for `MASS::polr` (@pbulsink, #289).
Expand Down
20 changes: 14 additions & 6 deletions R/xgb.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ NULL
#' @export
axe_call.xgb.Booster <- function(x, verbose = FALSE, ...) {
old <- x
x <- exchange(x, "call", call("dummy_call"))
if (is.null(attr(x, "call"))) {
x <- exchange(x, "call", call("dummy_call"))
} else {
attr(x, "call") <- call("dummy_call")
}

add_butcher_attributes(
x,
Expand All @@ -64,11 +68,15 @@ axe_call.xgb.Booster <- function(x, verbose = FALSE, ...) {
#' @export
axe_env.xgb.Booster <- function(x, verbose = FALSE, ...) {
old <- x
x$callbacks <- purrr::map(x$callbacks,
function(x)
as.function(c(formals(x), body(x)), env = rlang::base_env())
)

if (!is.null(x$callbacks)) {
x$callbacks <-
purrr::map(
x$callbacks,
function(x) {
as.function(c(formals(x), body(x)), env = rlang::base_env())
}
)
}
add_butcher_attributes(
x,
old,
Expand Down
104 changes: 71 additions & 33 deletions tests/testthat/test-xgb.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,42 @@ test_that("xgb.Booster + linear solver + predict() works", {
# Load data
data(agaricus.train)
data(agaricus.test)
bst <- xgboost(data = agaricus.train$data,
label = agaricus.train$label,
eta = 1,
nthread = 2,
nrounds = 2,
eval_metric = "logloss",
objective = "binary:logistic",
verbose = 0)
if (utils::packageVersion("xgboost") > "2.0.0.0") {
bst <- xgboost(
x = agaricus.train$data,
y = agaricus.train$label,
learning_rate = 1,
nthread = 2,
nrounds = 2,
eval_metric = "logloss",
objective = "reg:squarederror"
)
} else {
bst <- xgboost(
data = agaricus.train$data,
label = agaricus.train$label,
eta = 1,
nthread = 2,
nrounds = 2,
eval_metric = "logloss",
objective = "binary:logistic",
verbose = 0
)
}
x <- axe_call(bst)
expect_equal(x$call, rlang::expr(dummy_call()))
if (utils::packageVersion("xgboost") > "2.0.0.0") {
extracted_call <- attr(x, "call")
} else {
extracted_call <- x$call
}
expect_equal(extracted_call, rlang::expr(dummy_call()))
x <- axe_env(bst)
expect_lt(lobstr::obj_size(x), lobstr::obj_size(bst))
expect_lte(lobstr::obj_size(x), lobstr::obj_size(bst))
expect_lte(lobstr::obj_size(attributes(x)), lobstr::obj_size(attributes(bst)))
x <- butcher(bst)
expect_equal(xgb.importance(model = x),
xgb.importance(model = bst))
expect_equal(predict(x, agaricus.test$data),
predict(bst, agaricus.test$data))
expect_equal(xgb.dump(x, with_stats = TRUE),
xgb.dump(bst, with_stats = TRUE))
expect_equal(xgb.importance(model = x), xgb.importance(model = bst))
expect_equal(predict(x, agaricus.test$data), predict(bst, agaricus.test$data))
expect_equal(xgb.dump(x, with_stats = TRUE), xgb.dump(bst, with_stats = TRUE))
})

test_that("xgb.Booster + tree-learning algo + predict() works", {
Expand All @@ -37,24 +54,45 @@ test_that("xgb.Booster + tree-learning algo + predict() works", {
# Load data
data(agaricus.train)
data(agaricus.test)
dtrain <- xgb.DMatrix(data = agaricus.train$data,
label = agaricus.train$label)
bst <- xgb.train(data = dtrain,
booster = "gblinear",
nthread = 2,
nrounds = 2,
eval_metric = "logloss",
objective = "binary:logistic",
print_every_n = 10000L)
dtrain <- xgb.DMatrix(
data = agaricus.train$data,
label = agaricus.train$label
)
if (utils::packageVersion("xgboost") > "2.0.0.0") {
bst <- xgb.train(
params = list(
booster = "gblinear",
nthread = 2,
eval_metric = "logloss",
objective = "binary:logistic",
print_every_n = 10000L
),
nrounds = 2,
data = dtrain
)
} else {
bst <- xgb.train(
data = dtrain,
booster = "gblinear",
nthread = 2,
nrounds = 2,
eval_metric = "logloss",
objective = "binary:logistic",
print_every_n = 10000L
)
}
x <- axe_call(bst)
expect_equal(x$call, rlang::expr(dummy_call()))
if (utils::packageVersion("xgboost") > "2.0.0.0") {
extracted_call <- attr(x, "call")
} else {
extracted_call <- x$call
}
expect_equal(extracted_call, rlang::expr(dummy_call()))
x <- axe_env(bst)
expect_lt(lobstr::obj_size(x), lobstr::obj_size(bst))
expect_lte(lobstr::obj_size(x), lobstr::obj_size(bst))
expect_lte(lobstr::obj_size(attributes(x)), lobstr::obj_size(attributes(bst)))
x <- butcher(bst)
expect_equal(xgb.importance(model = x),
xgb.importance(model = bst))
expect_equal(predict(x, agaricus.test$data),
predict(bst, agaricus.test$data))
expect_equal(xgb.dump(x, with_stats = TRUE),
xgb.dump(bst, with_stats = TRUE))
expect_equal(xgb.importance(model = x), xgb.importance(model = bst))
expect_equal(predict(x, agaricus.test$data), predict(bst, agaricus.test$data))
expect_equal(xgb.dump(x, with_stats = TRUE), xgb.dump(bst, with_stats = TRUE))
})
41 changes: 35 additions & 6 deletions tests/testthat/test-xrf.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@ test_that("xrf + axe_call() works", {
family = 'gaussian'
)
x <- axe_call(res)
expect_equal(x$xgb$call, rlang::expr(dummy_call()))

# due to new xgboost version
# https://github.com/tidymodels/butcher/issues/294
if (is.null(x$xgb$call)) {
expect_equal(attr(x$xgb, "call"), rlang::expr(dummy_call()))
} else {
expect_equal(x$xgb$call, rlang::expr(dummy_call()))
}

expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call()))
expect_equal(x$glm$model$call, rlang::expr(dummy_call()))
})
Expand All @@ -24,9 +32,16 @@ test_that("xrf + axe_env() works", {
)
x <- axe_env(res)
expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env())
expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env())
expect_equal(
attr(x$rule_augmented_formula, ".Environment"),
rlang::base_env()
)
expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env())
expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env())
# due to new xgboost version
# https://github.com/tidymodels/butcher/issues/294
if (!is.null(x$xgb$callbacks)) {
expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env())
}
})

test_that("xrf + butcher() works", {
Expand All @@ -38,13 +53,27 @@ test_that("xrf + butcher() works", {
family = 'gaussian'
)
x <- butcher(res)
expect_equal(x$xgb$call, rlang::expr(dummy_call()))
# due to new xgboost version
# https://github.com/tidymodels/butcher/issues/294
if (is.null(x$xgb$call)) {
expect_equal(attr(x$xgb, "call"), rlang::expr(dummy_call()))
} else {
expect_equal(x$xgb$call, rlang::expr(dummy_call()))
}

expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call()))
expect_equal(x$glm$model$call, rlang::expr(dummy_call()))
expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env())
expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env())
expect_equal(
attr(x$rule_augmented_formula, ".Environment"),
rlang::base_env()
)
expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env())
expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env())
# due to new xgboost version
# https://github.com/tidymodels/butcher/issues/294
if (!is.null(x$xgb$callbacks)) {
expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env())
}
})

test_that("xrf + predict() works", {
Expand Down