Skip to content

Commit 60074af

Browse files
montanalowMontana Low
authored and
Montana Low
committed
tests pass
1 parent 247b262 commit 60074af

File tree

6 files changed

+27
-36
lines changed

6 files changed

+27
-36
lines changed

src/booster.rs

+20-10
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,8 @@ mod tests {
761761

762762
#[test]
763763
fn save_and_load_from_buffer() {
764-
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
764+
let dmat_train =
765+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
765766
let mut booster = Booster::new_with_cached_dmats(&BoosterParameters::default(), &[&dmat_train]).unwrap();
766767
let attr = booster.get_attribute("foo").expect("Getting attribute failed");
767768
assert_eq!(attr, None);
@@ -804,8 +805,10 @@ mod tests {
804805

805806
#[test]
806807
fn predict() {
807-
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
808-
let dmat_test =DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
808+
let dmat_train =
809+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
810+
let dmat_test =
811+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
809812

810813
let tree_params = tree::TreeBoosterParametersBuilder::default()
811814
.max_depth(2)
@@ -886,8 +889,10 @@ mod tests {
886889

887890
#[test]
888891
fn predict_leaf() {
889-
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
890-
let dmat_test = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
892+
let dmat_train =
893+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
894+
let dmat_test =
895+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
891896

892897
let tree_params = tree::TreeBoosterParametersBuilder::default()
893898
.max_depth(2)
@@ -919,8 +924,10 @@ mod tests {
919924

920925
#[test]
921926
fn predict_contributions() {
922-
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
923-
let dmat_test = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
927+
let dmat_train =
928+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
929+
let dmat_test =
930+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
924931

925932
let tree_params = tree::TreeBoosterParametersBuilder::default()
926933
.max_depth(2)
@@ -953,8 +960,10 @@ mod tests {
953960

954961
#[test]
955962
fn predict_interactions() {
956-
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
957-
let dmat_test = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
963+
let dmat_train =
964+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
965+
let dmat_test =
966+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
958967

959968
let tree_params = tree::TreeBoosterParametersBuilder::default()
960969
.max_depth(2)
@@ -1005,7 +1014,8 @@ mod tests {
10051014

10061015
#[test]
10071016
fn dump_model() {
1008-
let dmat_train = DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
1017+
let dmat_train =
1018+
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
10091019

10101020
println!("{:?}", dmat_train.shape());
10111021

src/dmatrix.rs

+2-10
Original file line numberDiff line numberDiff line change
@@ -199,23 +199,15 @@ impl DMatrix {
199199
debug!("Loading DMatrix from: {}", path.as_ref().display());
200200
let mut handle = ptr::null_mut();
201201
let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
202-
xgb_call!(xgboost_sys::XGDMatrixCreateFromURI(
203-
fname.as_ptr(),
204-
&mut handle
205-
))?;
202+
xgb_call!(xgboost_sys::XGDMatrixCreateFromURI(fname.as_ptr(), &mut handle))?;
206203
DMatrix::new(handle)
207204
}
208205

209-
210206
pub fn load_binary<P: AsRef<Path>>(path: P) -> XGBResult<Self> {
211207
debug!("Loading DMatrix from: {}", path.as_ref().display());
212208
let mut handle = ptr::null_mut();
213209
let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
214-
xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(
215-
fname.as_ptr(),
216-
1,
217-
&mut handle
218-
)).unwrap();
210+
xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(fname.as_ptr(), 1, &mut handle)).unwrap();
219211
DMatrix::new(handle)
220212
}
221213

src/parameters/dart.rs

+3-9
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ use std::default::Default;
66
use super::Interval;
77

88
/// Type of sampling algorithm.
9-
#[derive(Clone)]
10-
#[derive(Default)]
9+
#[derive(Clone, Default)]
1110
pub enum SampleType {
1211
/// Dropped trees are selected uniformly.
1312
#[default]
@@ -26,11 +25,8 @@ impl ToString for SampleType {
2625
}
2726
}
2827

29-
30-
3128
/// Type of normalization algorithm.
32-
#[derive(Clone)]
33-
#[derive(Default)]
29+
#[derive(Clone, Default)]
3430
pub enum NormalizeType {
3531
/// New trees have the same weight of each of dropped trees.
3632
/// * weight of new trees are 1 / (k + learning_rate)
@@ -54,8 +50,6 @@ impl ToString for NormalizeType {
5450
}
5551
}
5652

57-
58-
5953
/// Additional parameters for Dart Booster.
6054
#[derive(Builder, Clone)]
6155
#[builder(build_fn(validate = "Self::validate"))]
@@ -102,7 +96,7 @@ impl DartBoosterParameters {
10296
("normalize_type".to_owned(), self.normalize_type.to_string()),
10397
("rate_drop".to_owned(), self.rate_drop.to_string()),
10498
("one_drop".to_owned(), (self.one_drop as u8).to_string()),
105-
("skip_drop".to_owned(), self.skip_drop.to_string())
99+
("skip_drop".to_owned(), self.skip_drop.to_string()),
106100
]
107101
}
108102
}

src/parameters/learning.rs

-2
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ impl ToString for Objective {
100100
}
101101
}
102102

103-
104-
105103
/// Type of evaluation metrics to use during learning.
106104
#[derive(Clone)]
107105
pub enum Metrics {

src/parameters/linear.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
use std::default::Default;
44

55
/// Linear model algorithm.
6-
#[derive(Clone)]
7-
#[derive(Default)]
6+
#[derive(Clone, Default)]
87
pub enum LinearUpdate {
98
/// Parallel coordinate descent algorithm based on shotgun algorithm. Uses ‘hogwild’ parallelism and
109
/// therefore produces a nondeterministic solution on each run.
@@ -24,8 +23,6 @@ impl ToString for LinearUpdate {
2423
}
2524
}
2625

27-
28-
2926
/// BoosterParameters for Linear Booster.
3027
#[derive(Builder, Clone)]
3128
#[builder(default)]

src/parameters/tree.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ impl Default for TreeBoosterParameters {
357357

358358
impl TreeBoosterParameters {
359359
pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> {
360-
let mut v = vec! [
360+
let mut v = vec![
361361
("booster".to_owned(), "gbtree".to_owned()),
362362
("eta".to_owned(), self.eta.to_string()),
363363
("gamma".to_owned(), self.gamma.to_string()),

0 commit comments

Comments
 (0)