From 8b1b4681e4b1338c6b06dfda743084a2fc3674bd Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 20 Nov 2025 18:23:51 -0800 Subject: [PATCH 1/5] update xgb.Booster methods to work with newer versions of butcher --- R/xgb.R | 10 ++++- tests/testthat/test-xgb.R | 78 ++++++++++++++++++++++++++++----------- 2 files changed, 65 insertions(+), 23 deletions(-) diff --git a/R/xgb.R b/R/xgb.R index 6f9ec10d..f9c2a157 100644 --- a/R/xgb.R +++ b/R/xgb.R @@ -43,7 +43,11 @@ NULL #' @export axe_call.xgb.Booster <- function(x, verbose = FALSE, ...) { old <- x - x <- exchange(x, "call", call("dummy_call")) + if (is.null(attr(x, "call"))) { + x <- exchange(x, "call", call("dummy_call")) + } else { + attr(x, "call") <- call("dummy_call") + } add_butcher_attributes( x, @@ -64,10 +68,12 @@ axe_call.xgb.Booster <- function(x, verbose = FALSE, ...) { #' @export axe_env.xgb.Booster <- function(x, verbose = FALSE, ...) { old <- x - x$callbacks <- purrr::map(x$callbacks, + if (!is.null(x$callbacks)) { + x$callbacks <- purrr::map(x$callbacks, function(x) as.function(c(formals(x), body(x)), env = rlang::base_env()) ) + } add_butcher_attributes( x, diff --git a/tests/testthat/test-xgb.R b/tests/testthat/test-xgb.R index 74bfb39a..cd80407e 100644 --- a/tests/testthat/test-xgb.R +++ b/tests/testthat/test-xgb.R @@ -9,18 +9,34 @@ test_that("xgb.Booster + linear solver + predict() works", { # Load data data(agaricus.train) data(agaricus.test) - bst <- xgboost(data = agaricus.train$data, - label = agaricus.train$label, - eta = 1, - nthread = 2, - nrounds = 2, - eval_metric = "logloss", - objective = "binary:logistic", - verbose = 0) + if (utils::packageVersion("xgboost") > "2.0.0.0") { + bst <- xgboost(x = agaricus.train$data, + y = agaricus.train$label, + learning_rate = 1, + nthread = 2, + nrounds = 2, + eval_metric = "logloss", + objective = "reg:squarederror") + } else { + bst <- xgboost(data = agaricus.train$data, + label = agaricus.train$label, + eta = 1, + nthread = 2, + nrounds = 2, + 
eval_metric = "logloss", + objective = "binary:logistic", + verbose = 0) + } x <- axe_call(bst) - expect_equal(x$call, rlang::expr(dummy_call())) + if (utils::packageVersion("xgboost") > "2.0.0.0") { + extracted_call <- attr(x, "call") + } else { + extracted_call <- x$call + } + expect_equal(extracted_call, rlang::expr(dummy_call())) x <- axe_env(bst) - expect_lt(lobstr::obj_size(x), lobstr::obj_size(bst)) + expect_lte(lobstr::obj_size(x), lobstr::obj_size(bst)) + expect_lte(lobstr::obj_size(attributes(x)), lobstr::obj_size(attributes(bst))) x <- butcher(bst) expect_equal(xgb.importance(model = x), xgb.importance(model = bst)) @@ -28,7 +44,7 @@ test_that("xgb.Booster + linear solver + predict() works", { predict(bst, agaricus.test$data)) expect_equal(xgb.dump(x, with_stats = TRUE), xgb.dump(bst, with_stats = TRUE)) -}) + }) test_that("xgb.Booster + tree-learning algo + predict() works", { skip_on_cran() @@ -39,17 +55,37 @@ test_that("xgb.Booster + tree-learning algo + predict() works", { data(agaricus.test) dtrain <- xgb.DMatrix(data = agaricus.train$data, label = agaricus.train$label) - bst <- xgb.train(data = dtrain, - booster = "gblinear", - nthread = 2, - nrounds = 2, - eval_metric = "logloss", - objective = "binary:logistic", - print_every_n = 10000L) + if (utils::packageVersion("xgboost") > "2.0.0.0") { + bst <- xgb.train( + params = list( + booster = "gblinear", + nthread = 2, + eval_metric = "logloss", + objective = "binary:logistic", + print_every_n = 10000L + ), + nrounds = 2, + data = dtrain + ) + } else { + bst <- xgb.train(data = dtrain, + booster = "gblinear", + nthread = 2, + nrounds = 2, + eval_metric = "logloss", + objective = "binary:logistic", + print_every_n = 10000L) + } x <- axe_call(bst) - expect_equal(x$call, rlang::expr(dummy_call())) + if (utils::packageVersion("xgboost") > "2.0.0.0") { + extracted_call <- attr(x, "call") + } else { + extracted_call <- x$call + } + expect_equal(extracted_call, rlang::expr(dummy_call())) x <- axe_env(bst) - 
expect_lt(lobstr::obj_size(x), lobstr::obj_size(bst)) + expect_lte(lobstr::obj_size(x), lobstr::obj_size(bst)) + expect_lte(lobstr::obj_size(attributes(x)), lobstr::obj_size(attributes(bst))) x <- butcher(bst) expect_equal(xgb.importance(model = x), xgb.importance(model = bst)) @@ -57,4 +93,4 @@ test_that("xgb.Booster + tree-learning algo + predict() works", { predict(bst, agaricus.test$data)) expect_equal(xgb.dump(x, with_stats = TRUE), xgb.dump(bst, with_stats = TRUE)) -}) + }) From 51b4a32319a82ec382a5f6a5699a7c79d952d130 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 20 Nov 2025 18:47:04 -0800 Subject: [PATCH 2/5] make test-xrf.R xgboost proof --- tests/testthat/test-xrf.R | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-xrf.R b/tests/testthat/test-xrf.R index 15f849c8..a21e8074 100644 --- a/tests/testthat/test-xrf.R +++ b/tests/testthat/test-xrf.R @@ -9,7 +9,15 @@ test_that("xrf + axe_call() works", { family = 'gaussian' ) x <- axe_call(res) - expect_equal(x$xgb$call, rlang::expr(dummy_call())) + + # due to new xgboost version + # https://github.com/tidymodels/butcher/issues/294 + if (is.null(x$xgb$call)) { + expect_equal(attr(x$xgb, "call"), rlang::expr(dummy_call())) + } else { + expect_equal(x$xgb$call, rlang::expr(dummy_call())) + } + expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call())) expect_equal(x$glm$model$call, rlang::expr(dummy_call())) }) @@ -26,7 +34,12 @@ test_that("xrf + axe_env() works", { expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env()) expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env()) expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env()) - expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) + + # due to new xgboost version + # https://github.com/tidymodels/butcher/issues/294 + if (!is.null(x$xgb$callbacks)) { + expect_equal(environment(x$xgb$callbacks[[1]]), 
rlang::base_env()) + } }) test_that("xrf + butcher() works", { @@ -38,13 +51,26 @@ test_that("xrf + butcher() works", { family = 'gaussian' ) x <- butcher(res) - expect_equal(x$xgb$call, rlang::expr(dummy_call())) + + # due to new xgboost version + # https://github.com/tidymodels/butcher/issues/294 + if (is.null(x$xgb$call)) { + expect_equal(attr(x$xgb, "call"), rlang::expr(dummy_call())) + } else { + expect_equal(x$xgb$call, rlang::expr(dummy_call())) + } + expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call())) expect_equal(x$glm$model$call, rlang::expr(dummy_call())) expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env()) expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env()) expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env()) - expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) + + # due to new xgboost version + # https://github.com/tidymodels/butcher/issues/294 + if (!is.null(x$xgb$callbacks)) { + expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) + } }) test_that("xrf + predict() works", { From 3c64253b3be7d00e934f1d2ec88948a661e069e9 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 20 Nov 2025 19:45:45 -0800 Subject: [PATCH 3/5] add news --- NEWS.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 81dc0d71..c483fe64 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,8 @@ # butcher (development version) -* * Added methods for `MASS::polr` (@pbulsink, #289). +* Added methods for `MASS::polr` (@pbulsink, #289). + +* Updated the `xgb.Booster` methods to work with models from newer versions of xgboost (#294). 
# butcher 0.3.6 From ba1aafc2187c71634c4d827376fd130bc0dc45fd Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Wed, 3 Dec 2025 07:59:45 -0500 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- tests/testthat/test-xgb.R | 60 ++++++++++++++++++++++----------------- tests/testthat/test-xrf.R | 13 +++++---- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/tests/testthat/test-xgb.R b/tests/testthat/test-xgb.R index cd80407e..03c53a3d 100644 --- a/tests/testthat/test-xgb.R +++ b/tests/testthat/test-xgb.R @@ -10,26 +10,30 @@ test_that("xgb.Booster + linear solver + predict() works", { data(agaricus.train) data(agaricus.test) if (utils::packageVersion("xgboost") > "2.0.0.0") { - bst <- xgboost(x = agaricus.train$data, - y = agaricus.train$label, - learning_rate = 1, - nthread = 2, - nrounds = 2, - eval_metric = "logloss", - objective = "reg:squarederror") + bst <- xgboost( + x = agaricus.train$data, + y = agaricus.train$label, + learning_rate = 1, + nthread = 2, + nrounds = 2, + eval_metric = "logloss", + objective = "reg:squarederror" + ) } else { - bst <- xgboost(data = agaricus.train$data, - label = agaricus.train$label, - eta = 1, - nthread = 2, - nrounds = 2, - eval_metric = "logloss", - objective = "binary:logistic", - verbose = 0) + bst <- xgboost( + data = agaricus.train$data, + label = agaricus.train$label, + eta = 1, + nthread = 2, + nrounds = 2, + eval_metric = "logloss", + objective = "binary:logistic", + verbose = 0 + ) } x <- axe_call(bst) if (utils::packageVersion("xgboost") > "2.0.0.0") { - extracted_call <- attr(x, "call") + extracted_call <- attr(x, "call") } else { extracted_call <- x$call } @@ -53,8 +57,10 @@ test_that("xgb.Booster + tree-learning algo + predict() works", { # Load data data(agaricus.train) data(agaricus.test) - dtrain <- xgb.DMatrix(data = agaricus.train$data, - label = agaricus.train$label) + dtrain <- xgb.DMatrix( + data = 
agaricus.train$data, + label = agaricus.train$label + ) if (utils::packageVersion("xgboost") > "2.0.0.0") { bst <- xgb.train( params = list( @@ -68,17 +74,19 @@ test_that("xgb.Booster + tree-learning algo + predict() works", { data = dtrain ) } else { - bst <- xgb.train(data = dtrain, - booster = "gblinear", - nthread = 2, - nrounds = 2, - eval_metric = "logloss", - objective = "binary:logistic", - print_every_n = 10000L) + bst <- xgb.train( + data = dtrain, + booster = "gblinear", + nthread = 2, + nrounds = 2, + eval_metric = "logloss", + objective = "binary:logistic", + print_every_n = 10000L + ) } x <- axe_call(bst) if (utils::packageVersion("xgboost") > "2.0.0.0") { - extracted_call <- attr(x, "call") + extracted_call <- attr(x, "call") } else { extracted_call <- x$call } diff --git a/tests/testthat/test-xrf.R b/tests/testthat/test-xrf.R index a21e8074..2d1f2cc0 100644 --- a/tests/testthat/test-xrf.R +++ b/tests/testthat/test-xrf.R @@ -32,9 +32,11 @@ test_that("xrf + axe_env() works", { ) x <- axe_env(res) expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env()) - expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env()) + expect_equal( + attr(x$rule_augmented_formula, ".Environment"), + rlang::base_env() + ) expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env()) - # due to new xgboost version # https://github.com/tidymodels/butcher/issues/294 if (!is.null(x$xgb$callbacks)) { @@ -51,7 +53,6 @@ test_that("xrf + butcher() works", { family = 'gaussian' ) x <- butcher(res) - # due to new xgboost version # https://github.com/tidymodels/butcher/issues/294 if (is.null(x$xgb$call)) { @@ -63,9 +64,11 @@ test_that("xrf + butcher() works", { expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call())) expect_equal(x$glm$model$call, rlang::expr(dummy_call())) expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env()) - expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env()) + 
expect_equal( + attr(x$rule_augmented_formula, ".Environment"), + rlang::base_env() + ) expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env()) - # due to new xgboost version # https://github.com/tidymodels/butcher/issues/294 if (!is.null(x$xgb$callbacks)) { From 240935ea1e76beb345095c20a082c57b65344c17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= Date: Wed, 3 Dec 2025 08:02:09 -0500 Subject: [PATCH 5/5] more formatting --- R/xgb.R | 12 +++++++----- tests/testthat/test-xgb.R | 22 ++++++++-------------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/R/xgb.R b/R/xgb.R index f9c2a157..53752f84 100644 --- a/R/xgb.R +++ b/R/xgb.R @@ -69,12 +69,14 @@ axe_call.xgb.Booster <- function(x, verbose = FALSE, ...) { axe_env.xgb.Booster <- function(x, verbose = FALSE, ...) { old <- x if (!is.null(x$callbacks)) { - x$callbacks <- purrr::map(x$callbacks, - function(x) - as.function(c(formals(x), body(x)), env = rlang::base_env()) - ) + x$callbacks <- + purrr::map( + x$callbacks, + function(x) { + as.function(c(formals(x), body(x)), env = rlang::base_env()) + } + ) } - add_butcher_attributes( x, old, diff --git a/tests/testthat/test-xgb.R b/tests/testthat/test-xgb.R index 03c53a3d..d6f62b4d 100644 --- a/tests/testthat/test-xgb.R +++ b/tests/testthat/test-xgb.R @@ -42,13 +42,10 @@ test_that("xgb.Booster + linear solver + predict() works", { expect_lte(lobstr::obj_size(x), lobstr::obj_size(bst)) expect_lte(lobstr::obj_size(attributes(x)), lobstr::obj_size(attributes(bst))) x <- butcher(bst) - expect_equal(xgb.importance(model = x), - xgb.importance(model = bst)) - expect_equal(predict(x, agaricus.test$data), - predict(bst, agaricus.test$data)) - expect_equal(xgb.dump(x, with_stats = TRUE), - xgb.dump(bst, with_stats = TRUE)) - }) + expect_equal(xgb.importance(model = x), xgb.importance(model = bst)) + expect_equal(predict(x, agaricus.test$data), predict(bst, agaricus.test$data)) + expect_equal(xgb.dump(x, with_stats = TRUE), 
xgb.dump(bst, with_stats = TRUE)) +}) test_that("xgb.Booster + tree-learning algo + predict() works", { skip_on_cran() @@ -95,10 +92,7 @@ test_that("xgb.Booster + tree-learning algo + predict() works", { expect_lte(lobstr::obj_size(x), lobstr::obj_size(bst)) expect_lte(lobstr::obj_size(attributes(x)), lobstr::obj_size(attributes(bst))) x <- butcher(bst) - expect_equal(xgb.importance(model = x), - xgb.importance(model = bst)) - expect_equal(predict(x, agaricus.test$data), - predict(bst, agaricus.test$data)) - expect_equal(xgb.dump(x, with_stats = TRUE), - xgb.dump(bst, with_stats = TRUE)) - }) + expect_equal(xgb.importance(model = x), xgb.importance(model = bst)) + expect_equal(predict(x, agaricus.test$data), predict(bst, agaricus.test$data)) + expect_equal(xgb.dump(x, with_stats = TRUE), xgb.dump(bst, with_stats = TRUE)) +})