@@ -9,25 +9,42 @@ test_that("xgb.Booster + linear solver + predict() works", {
99 # Load data
1010 data(agaricus.train )
1111 data(agaricus.test )
12- bst <- xgboost(data = agaricus.train $ data ,
13- label = agaricus.train $ label ,
14- eta = 1 ,
15- nthread = 2 ,
16- nrounds = 2 ,
17- eval_metric = " logloss" ,
18- objective = " binary:logistic" ,
19- verbose = 0 )
12+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
13+ bst <- xgboost(
14+ x = agaricus.train $ data ,
15+ y = agaricus.train $ label ,
16+ learning_rate = 1 ,
17+ nthread = 2 ,
18+ nrounds = 2 ,
19+ eval_metric = " logloss" ,
20+ objective = " reg:squarederror"
21+ )
22+ } else {
23+ bst <- xgboost(
24+ data = agaricus.train $ data ,
25+ label = agaricus.train $ label ,
26+ eta = 1 ,
27+ nthread = 2 ,
28+ nrounds = 2 ,
29+ eval_metric = " logloss" ,
30+ objective = " binary:logistic" ,
31+ verbose = 0
32+ )
33+ }
2034 x <- axe_call(bst )
21- expect_equal(x $ call , rlang :: expr(dummy_call()))
35+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
36+ extracted_call <- attr(x , " call" )
37+ } else {
38+ extracted_call <- x $ call
39+ }
40+ expect_equal(extracted_call , rlang :: expr(dummy_call()))
2241 x <- axe_env(bst )
23- expect_lt(lobstr :: obj_size(x ), lobstr :: obj_size(bst ))
42+ expect_lte(lobstr :: obj_size(x ), lobstr :: obj_size(bst ))
43+ expect_lte(lobstr :: obj_size(attributes(x )), lobstr :: obj_size(attributes(bst )))
2444 x <- butcher(bst )
25- expect_equal(xgb.importance(model = x ),
26- xgb.importance(model = bst ))
27- expect_equal(predict(x , agaricus.test $ data ),
28- predict(bst , agaricus.test $ data ))
29- expect_equal(xgb.dump(x , with_stats = TRUE ),
30- xgb.dump(bst , with_stats = TRUE ))
45+ expect_equal(xgb.importance(model = x ), xgb.importance(model = bst ))
46+ expect_equal(predict(x , agaricus.test $ data ), predict(bst , agaricus.test $ data ))
47+ expect_equal(xgb.dump(x , with_stats = TRUE ), xgb.dump(bst , with_stats = TRUE ))
3148})
3249
3350test_that(" xgb.Booster + tree-learning algo + predict() works" , {
@@ -37,24 +54,45 @@ test_that("xgb.Booster + tree-learning algo + predict() works", {
3754 # Load data
3855 data(agaricus.train )
3956 data(agaricus.test )
40- dtrain <- xgb.DMatrix(data = agaricus.train $ data ,
41- label = agaricus.train $ label )
42- bst <- xgb.train(data = dtrain ,
43- booster = " gblinear" ,
44- nthread = 2 ,
45- nrounds = 2 ,
46- eval_metric = " logloss" ,
47- objective = " binary:logistic" ,
48- print_every_n = 10000L )
57+ dtrain <- xgb.DMatrix(
58+ data = agaricus.train $ data ,
59+ label = agaricus.train $ label
60+ )
61+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
62+ bst <- xgb.train(
63+ params = list (
64+ booster = " gblinear" ,
65+ nthread = 2 ,
66+ eval_metric = " logloss" ,
67+ objective = " binary:logistic" ,
68+ print_every_n = 10000L
69+ ),
70+ nrounds = 2 ,
71+ data = dtrain
72+ )
73+ } else {
74+ bst <- xgb.train(
75+ data = dtrain ,
76+ booster = " gblinear" ,
77+ nthread = 2 ,
78+ nrounds = 2 ,
79+ eval_metric = " logloss" ,
80+ objective = " binary:logistic" ,
81+ print_every_n = 10000L
82+ )
83+ }
4984 x <- axe_call(bst )
50- expect_equal(x $ call , rlang :: expr(dummy_call()))
85+ if (utils :: packageVersion(" xgboost" ) > " 2.0.0.0" ) {
86+ extracted_call <- attr(x , " call" )
87+ } else {
88+ extracted_call <- x $ call
89+ }
90+ expect_equal(extracted_call , rlang :: expr(dummy_call()))
5191 x <- axe_env(bst )
52- expect_lt(lobstr :: obj_size(x ), lobstr :: obj_size(bst ))
92+ expect_lte(lobstr :: obj_size(x ), lobstr :: obj_size(bst ))
93+ expect_lte(lobstr :: obj_size(attributes(x )), lobstr :: obj_size(attributes(bst )))
5394 x <- butcher(bst )
54- expect_equal(xgb.importance(model = x ),
55- xgb.importance(model = bst ))
56- expect_equal(predict(x , agaricus.test $ data ),
57- predict(bst , agaricus.test $ data ))
58- expect_equal(xgb.dump(x , with_stats = TRUE ),
59- xgb.dump(bst , with_stats = TRUE ))
95+ expect_equal(xgb.importance(model = x ), xgb.importance(model = bst ))
96+ expect_equal(predict(x , agaricus.test $ data ), predict(bst , agaricus.test $ data ))
97+ expect_equal(xgb.dump(x , with_stats = TRUE ), xgb.dump(bst , with_stats = TRUE ))
6098})
0 commit comments