diff --git a/NEWS.md b/NEWS.md index 17beb45..e1ec286 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # butcher (development version) +* Added methods for `MASS::polr` (@pbulsink, #289). + +* Make to work with new versions of xgboost models (#294). + * Added butcher methods for `tabnet()` (@cregouby #226). * * Added methods for `MASS::polr` (@pbulsink, #289). diff --git a/R/xgb.R b/R/xgb.R index 6f9ec10..53752f8 100644 --- a/R/xgb.R +++ b/R/xgb.R @@ -43,7 +43,11 @@ NULL #' @export axe_call.xgb.Booster <- function(x, verbose = FALSE, ...) { old <- x - x <- exchange(x, "call", call("dummy_call")) + if (is.null(attr(x, "call"))) { + x <- exchange(x, "call", call("dummy_call")) + } else { + attr(x, "call") <- call("dummy_call") + } add_butcher_attributes( x, @@ -64,11 +68,15 @@ axe_call.xgb.Booster <- function(x, verbose = FALSE, ...) { #' @export axe_env.xgb.Booster <- function(x, verbose = FALSE, ...) { old <- x - x$callbacks <- purrr::map(x$callbacks, - function(x) - as.function(c(formals(x), body(x)), env = rlang::base_env()) - ) - + if (!is.null(x$callbacks)) { + x$callbacks <- + purrr::map( + x$callbacks, + function(x) { + as.function(c(formals(x), body(x)), env = rlang::base_env()) + } + ) + } add_butcher_attributes( x, old, diff --git a/tests/testthat/test-xgb.R b/tests/testthat/test-xgb.R index 74bfb39..d6f62b4 100644 --- a/tests/testthat/test-xgb.R +++ b/tests/testthat/test-xgb.R @@ -9,25 +9,42 @@ test_that("xgb.Booster + linear solver + predict() works", { # Load data data(agaricus.train) data(agaricus.test) - bst <- xgboost(data = agaricus.train$data, - label = agaricus.train$label, - eta = 1, - nthread = 2, - nrounds = 2, - eval_metric = "logloss", - objective = "binary:logistic", - verbose = 0) + if (utils::packageVersion("xgboost") > "2.0.0.0") { + bst <- xgboost( + x = agaricus.train$data, + y = agaricus.train$label, + learning_rate = 1, + nthread = 2, + nrounds = 2, + eval_metric = "logloss", + objective = "reg:squarederror" + ) + } else { + bst <- xgboost( + data = agaricus.train$data, + label = agaricus.train$label, + eta = 1, + nthread = 2, + nrounds = 2, + eval_metric = "logloss", + objective = "binary:logistic", + verbose = 0 + ) + } x <- axe_call(bst) - expect_equal(x$call, rlang::expr(dummy_call())) + if (utils::packageVersion("xgboost") > "2.0.0.0") { + extracted_call <- attr(x, "call") + } else { + extracted_call <- x$call + } + expect_equal(extracted_call, rlang::expr(dummy_call())) x <- axe_env(bst) - expect_lt(lobstr::obj_size(x), lobstr::obj_size(bst)) + expect_lte(lobstr::obj_size(x), lobstr::obj_size(bst)) + expect_lte(lobstr::obj_size(attributes(x)), lobstr::obj_size(attributes(bst))) x <- butcher(bst) - expect_equal(xgb.importance(model = x), - xgb.importance(model = bst)) - expect_equal(predict(x, agaricus.test$data), - predict(bst, agaricus.test$data)) - expect_equal(xgb.dump(x, with_stats = TRUE), - xgb.dump(bst, with_stats = TRUE)) + expect_equal(xgb.importance(model = x), xgb.importance(model = bst)) + expect_equal(predict(x, agaricus.test$data), predict(bst, agaricus.test$data)) + expect_equal(xgb.dump(x, with_stats = TRUE), xgb.dump(bst, with_stats = TRUE)) }) test_that("xgb.Booster + tree-learning algo + predict() works", { @@ -37,24 +54,45 @@ test_that("xgb.Booster + tree-learning algo + predict() works", { # Load data data(agaricus.train) data(agaricus.test) - dtrain <- xgb.DMatrix(data = agaricus.train$data, - label = agaricus.train$label) - bst <- xgb.train(data = dtrain, - booster = "gblinear", - nthread = 2, - nrounds = 2, - eval_metric = "logloss", - objective = "binary:logistic", - print_every_n = 10000L) + dtrain <- xgb.DMatrix( + data = agaricus.train$data, + label = agaricus.train$label + ) + if (utils::packageVersion("xgboost") > "2.0.0.0") { + bst <- xgb.train( + params = list( + booster = "gblinear", + nthread = 2, + eval_metric = "logloss", + objective = "binary:logistic", + print_every_n = 10000L + ), + nrounds = 2, + data = dtrain + ) + } else { + bst <- xgb.train( + data = dtrain, + booster = "gblinear", + nthread = 2, + nrounds = 2, + eval_metric = "logloss", + objective = "binary:logistic", + print_every_n = 10000L + ) + } x <- axe_call(bst) - expect_equal(x$call, rlang::expr(dummy_call())) + if (utils::packageVersion("xgboost") > "2.0.0.0") { + extracted_call <- attr(x, "call") + } else { + extracted_call <- x$call + } + expect_equal(extracted_call, rlang::expr(dummy_call())) x <- axe_env(bst) - expect_lt(lobstr::obj_size(x), lobstr::obj_size(bst)) + expect_lte(lobstr::obj_size(x), lobstr::obj_size(bst)) + expect_lte(lobstr::obj_size(attributes(x)), lobstr::obj_size(attributes(bst))) x <- butcher(bst) - expect_equal(xgb.importance(model = x), - xgb.importance(model = bst)) - expect_equal(predict(x, agaricus.test$data), - predict(bst, agaricus.test$data)) - expect_equal(xgb.dump(x, with_stats = TRUE), - xgb.dump(bst, with_stats = TRUE)) + expect_equal(xgb.importance(model = x), xgb.importance(model = bst)) + expect_equal(predict(x, agaricus.test$data), predict(bst, agaricus.test$data)) + expect_equal(xgb.dump(x, with_stats = TRUE), xgb.dump(bst, with_stats = TRUE)) }) diff --git a/tests/testthat/test-xrf.R b/tests/testthat/test-xrf.R index 15f849c..2d1f2cc 100644 --- a/tests/testthat/test-xrf.R +++ b/tests/testthat/test-xrf.R @@ -9,7 +9,15 @@ test_that("xrf + axe_call() works", { family = 'gaussian' ) x <- axe_call(res) - expect_equal(x$xgb$call, rlang::expr(dummy_call())) + + # due to new xgboost version + # https://github.com/tidymodels/butcher/issues/294 + if (is.null(x$xgb$call)) { + expect_equal(attr(x$xgb, "call"), rlang::expr(dummy_call())) + } else { + expect_equal(x$xgb$call, rlang::expr(dummy_call())) + } + expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call())) expect_equal(x$glm$model$call, rlang::expr(dummy_call())) }) @@ -24,9 +32,16 @@ test_that("xrf + axe_env() works", { ) x <- axe_env(res) expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env()) - expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env()) + expect_equal( + attr(x$rule_augmented_formula, ".Environment"), + rlang::base_env() + ) expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env()) - expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) + # due to new xgboost version + # https://github.com/tidymodels/butcher/issues/294 + if (!is.null(x$xgb$callbacks)) { + expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) + } }) test_that("xrf + butcher() works", { @@ -38,13 +53,27 @@ test_that("xrf + butcher() works", { family = 'gaussian' ) x <- butcher(res) - expect_equal(x$xgb$call, rlang::expr(dummy_call())) + # due to new xgboost version + # https://github.com/tidymodels/butcher/issues/294 + if (is.null(x$xgb$call)) { + expect_equal(attr(x$xgb, "call"), rlang::expr(dummy_call())) + } else { + expect_equal(x$xgb$call, rlang::expr(dummy_call())) + } + expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call())) expect_equal(x$glm$model$call, rlang::expr(dummy_call())) expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env()) - expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env()) + expect_equal( + attr(x$rule_augmented_formula, ".Environment"), + rlang::base_env() + ) expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env()) - expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) + # due to new xgboost version + # https://github.com/tidymodels/butcher/issues/294 + if (!is.null(x$xgb$callbacks)) { + expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) + } }) test_that("xrf + predict() works", {