Skip to content

Commit 2f38399

Browse files
EmilHvitfeldttopepogithub-actions[bot]
authored
Make package work with both versions of xgboost models (#294)
* update xgb.Booster methods to work with newer versions of butcher * make test-xrf.R xgboost proof * add news * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * more formatting --------- Co-authored-by: ‘topepo’ <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 2f96a9b commit 2f38399

File tree

4 files changed

+124
-45
lines changed

4 files changed

+124
-45
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# butcher (development version)
22

3+
* Added methods for `MASS::polr` (@pbulsink, #289).
4+
5+
* Make to work with new versions of xgboost models (#294).
6+
37
* Added butcher methods for `tabnet()` (@cregouby #226).
48

59
* * Added methods for `MASS::polr` (@pbulsink, #289).

R/xgb.R

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ NULL
4343
#' @export
4444
axe_call.xgb.Booster <- function(x, verbose = FALSE, ...) {
4545
old <- x
46-
x <- exchange(x, "call", call("dummy_call"))
46+
if (is.null(attr(x, "call"))) {
47+
x <- exchange(x, "call", call("dummy_call"))
48+
} else {
49+
attr(x, "call") <- call("dummy_call")
50+
}
4751

4852
add_butcher_attributes(
4953
x,
@@ -64,11 +68,15 @@ axe_call.xgb.Booster <- function(x, verbose = FALSE, ...) {
6468
#' @export
6569
axe_env.xgb.Booster <- function(x, verbose = FALSE, ...) {
6670
old <- x
67-
x$callbacks <- purrr::map(x$callbacks,
68-
function(x)
69-
as.function(c(formals(x), body(x)), env = rlang::base_env())
70-
)
71-
71+
if (!is.null(x$callbacks)) {
72+
x$callbacks <-
73+
purrr::map(
74+
x$callbacks,
75+
function(x) {
76+
as.function(c(formals(x), body(x)), env = rlang::base_env())
77+
}
78+
)
79+
}
7280
add_butcher_attributes(
7381
x,
7482
old,

tests/testthat/test-xgb.R

Lines changed: 71 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,42 @@ test_that("xgb.Booster + linear solver + predict() works", {
99
# Load data
1010
data(agaricus.train)
1111
data(agaricus.test)
12-
bst <- xgboost(data = agaricus.train$data,
13-
label = agaricus.train$label,
14-
eta = 1,
15-
nthread = 2,
16-
nrounds = 2,
17-
eval_metric = "logloss",
18-
objective = "binary:logistic",
19-
verbose = 0)
12+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
13+
bst <- xgboost(
14+
x = agaricus.train$data,
15+
y = agaricus.train$label,
16+
learning_rate = 1,
17+
nthread = 2,
18+
nrounds = 2,
19+
eval_metric = "logloss",
20+
objective = "reg:squarederror"
21+
)
22+
} else {
23+
bst <- xgboost(
24+
data = agaricus.train$data,
25+
label = agaricus.train$label,
26+
eta = 1,
27+
nthread = 2,
28+
nrounds = 2,
29+
eval_metric = "logloss",
30+
objective = "binary:logistic",
31+
verbose = 0
32+
)
33+
}
2034
x <- axe_call(bst)
21-
expect_equal(x$call, rlang::expr(dummy_call()))
35+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
36+
extracted_call <- attr(x, "call")
37+
} else {
38+
extracted_call <- x$call
39+
}
40+
expect_equal(extracted_call, rlang::expr(dummy_call()))
2241
x <- axe_env(bst)
23-
expect_lt(lobstr::obj_size(x), lobstr::obj_size(bst))
42+
expect_lte(lobstr::obj_size(x), lobstr::obj_size(bst))
43+
expect_lte(lobstr::obj_size(attributes(x)), lobstr::obj_size(attributes(bst)))
2444
x <- butcher(bst)
25-
expect_equal(xgb.importance(model = x),
26-
xgb.importance(model = bst))
27-
expect_equal(predict(x, agaricus.test$data),
28-
predict(bst, agaricus.test$data))
29-
expect_equal(xgb.dump(x, with_stats = TRUE),
30-
xgb.dump(bst, with_stats = TRUE))
45+
expect_equal(xgb.importance(model = x), xgb.importance(model = bst))
46+
expect_equal(predict(x, agaricus.test$data), predict(bst, agaricus.test$data))
47+
expect_equal(xgb.dump(x, with_stats = TRUE), xgb.dump(bst, with_stats = TRUE))
3148
})
3249

3350
test_that("xgb.Booster + tree-learning algo + predict() works", {
@@ -37,24 +54,45 @@ test_that("xgb.Booster + tree-learning algo + predict() works", {
3754
# Load data
3855
data(agaricus.train)
3956
data(agaricus.test)
40-
dtrain <- xgb.DMatrix(data = agaricus.train$data,
41-
label = agaricus.train$label)
42-
bst <- xgb.train(data = dtrain,
43-
booster = "gblinear",
44-
nthread = 2,
45-
nrounds = 2,
46-
eval_metric = "logloss",
47-
objective = "binary:logistic",
48-
print_every_n = 10000L)
57+
dtrain <- xgb.DMatrix(
58+
data = agaricus.train$data,
59+
label = agaricus.train$label
60+
)
61+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
62+
bst <- xgb.train(
63+
params = list(
64+
booster = "gblinear",
65+
nthread = 2,
66+
eval_metric = "logloss",
67+
objective = "binary:logistic",
68+
print_every_n = 10000L
69+
),
70+
nrounds = 2,
71+
data = dtrain
72+
)
73+
} else {
74+
bst <- xgb.train(
75+
data = dtrain,
76+
booster = "gblinear",
77+
nthread = 2,
78+
nrounds = 2,
79+
eval_metric = "logloss",
80+
objective = "binary:logistic",
81+
print_every_n = 10000L
82+
)
83+
}
4984
x <- axe_call(bst)
50-
expect_equal(x$call, rlang::expr(dummy_call()))
85+
if (utils::packageVersion("xgboost") > "2.0.0.0") {
86+
extracted_call <- attr(x, "call")
87+
} else {
88+
extracted_call <- x$call
89+
}
90+
expect_equal(extracted_call, rlang::expr(dummy_call()))
5191
x <- axe_env(bst)
52-
expect_lt(lobstr::obj_size(x), lobstr::obj_size(bst))
92+
expect_lte(lobstr::obj_size(x), lobstr::obj_size(bst))
93+
expect_lte(lobstr::obj_size(attributes(x)), lobstr::obj_size(attributes(bst)))
5394
x <- butcher(bst)
54-
expect_equal(xgb.importance(model = x),
55-
xgb.importance(model = bst))
56-
expect_equal(predict(x, agaricus.test$data),
57-
predict(bst, agaricus.test$data))
58-
expect_equal(xgb.dump(x, with_stats = TRUE),
59-
xgb.dump(bst, with_stats = TRUE))
95+
expect_equal(xgb.importance(model = x), xgb.importance(model = bst))
96+
expect_equal(predict(x, agaricus.test$data), predict(bst, agaricus.test$data))
97+
expect_equal(xgb.dump(x, with_stats = TRUE), xgb.dump(bst, with_stats = TRUE))
6098
})

tests/testthat/test-xrf.R

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@ test_that("xrf + axe_call() works", {
99
family = 'gaussian'
1010
)
1111
x <- axe_call(res)
12-
expect_equal(x$xgb$call, rlang::expr(dummy_call()))
12+
13+
# due to new xgboost version
14+
# https://github.com/tidymodels/butcher/issues/294
15+
if (is.null(x$xgb$call)) {
16+
expect_equal(attr(x$xgb, "call"), rlang::expr(dummy_call()))
17+
} else {
18+
expect_equal(x$xgb$call, rlang::expr(dummy_call()))
19+
}
20+
1321
expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call()))
1422
expect_equal(x$glm$model$call, rlang::expr(dummy_call()))
1523
})
@@ -24,9 +32,16 @@ test_that("xrf + axe_env() works", {
2432
)
2533
x <- axe_env(res)
2634
expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env())
27-
expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env())
35+
expect_equal(
36+
attr(x$rule_augmented_formula, ".Environment"),
37+
rlang::base_env()
38+
)
2839
expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env())
29-
expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env())
40+
# due to new xgboost version
41+
# https://github.com/tidymodels/butcher/issues/294
42+
if (!is.null(x$xgb$callbacks)) {
43+
expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env())
44+
}
3045
})
3146

3247
test_that("xrf + butcher() works", {
@@ -38,13 +53,27 @@ test_that("xrf + butcher() works", {
3853
family = 'gaussian'
3954
)
4055
x <- butcher(res)
41-
expect_equal(x$xgb$call, rlang::expr(dummy_call()))
56+
# due to new xgboost version
57+
# https://github.com/tidymodels/butcher/issues/294
58+
if (is.null(x$xgb$call)) {
59+
expect_equal(attr(x$xgb, "call"), rlang::expr(dummy_call()))
60+
} else {
61+
expect_equal(x$xgb$call, rlang::expr(dummy_call()))
62+
}
63+
4264
expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call()))
4365
expect_equal(x$glm$model$call, rlang::expr(dummy_call()))
4466
expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env())
45-
expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env())
67+
expect_equal(
68+
attr(x$rule_augmented_formula, ".Environment"),
69+
rlang::base_env()
70+
)
4671
expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env())
47-
expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env())
72+
# due to new xgboost version
73+
# https://github.com/tidymodels/butcher/issues/294
74+
if (!is.null(x$xgb$callbacks)) {
75+
expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env())
76+
}
4877
})
4978

5079
test_that("xrf + predict() works", {

0 commit comments

Comments
 (0)