Skip to content

Commit 247b262

Browse files
montanalowMontana Low
authored and
Montana Low
committed
binary
1 parent e33922d commit 247b262

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

src/dmatrix.rs

+16-5
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,19 @@ impl DMatrix {
206206
DMatrix::new(handle)
207207
}
208208

209+
210+
pub fn load_binary<P: AsRef<Path>>(path: P) -> XGBResult<Self> {
211+
debug!("Loading DMatrix from: {}", path.as_ref().display());
212+
let mut handle = ptr::null_mut();
213+
let fname = ffi::CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
214+
xgb_call!(xgboost_sys::XGDMatrixCreateFromFile(
215+
fname.as_ptr(),
216+
1,
217+
&mut handle
218+
)).unwrap();
219+
DMatrix::new(handle)
220+
}
221+
209222
/// Serialise this `DMatrix` as a binary file to given path.
210223
pub fn save<P: AsRef<Path>>(&self, path: P) -> XGBResult<()> {
211224
debug!("Writing DMatrix to: {}", path.as_ref().display());
@@ -383,12 +396,10 @@ mod tests {
383396
let out_path = tmp_dir.path().join("dmat.bin");
384397
dmat.save(&out_path).unwrap();
385398

386-
let out_path = out_path.to_string_lossy();
387-
// let read_path = format!(r#"{{"uri": "{out_path}?format=csv"}}"#);
388-
// let dmat2 = DMatrix::load(&read_path).unwrap();
399+
let dmat2 = DMatrix::load_binary(out_path).unwrap();
389400

390-
// assert_eq!(dmat.num_rows(), dmat2.num_rows());
391-
// assert_eq!(dmat.num_cols(), dmat2.num_cols());
401+
assert_eq!(dmat.num_rows(), dmat2.num_rows());
402+
assert_eq!(dmat.num_cols(), dmat2.num_cols());
392403
// TODO: check contents as well, if possible
393404
}
394405

0 commit comments

Comments
 (0)