@@ -206,6 +206,19 @@ impl DMatrix {
206
206
DMatrix :: new ( handle)
207
207
}
208
208
209
+
210
+ pub fn load_binary < P : AsRef < Path > > ( path : P ) -> XGBResult < Self > {
211
+ debug ! ( "Loading DMatrix from: {}" , path. as_ref( ) . display( ) ) ;
212
+ let mut handle = ptr:: null_mut ( ) ;
213
+ let fname = ffi:: CString :: new ( path. as_ref ( ) . as_os_str ( ) . as_bytes ( ) ) . unwrap ( ) ;
214
+ xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromFile (
215
+ fname. as_ptr( ) ,
216
+ 1 ,
217
+ & mut handle
218
+ ) ) . unwrap ( ) ;
219
+ DMatrix :: new ( handle)
220
+ }
221
+
209
222
/// Serialise this `DMatrix` as a binary file to given path.
210
223
pub fn save < P : AsRef < Path > > ( & self , path : P ) -> XGBResult < ( ) > {
211
224
debug ! ( "Writing DMatrix to: {}" , path. as_ref( ) . display( ) ) ;
@@ -383,12 +396,10 @@ mod tests {
383
396
let out_path = tmp_dir. path ( ) . join ( "dmat.bin" ) ;
384
397
dmat. save ( & out_path) . unwrap ( ) ;
385
398
386
- let out_path = out_path. to_string_lossy ( ) ;
387
- // let read_path = format!(r#"{{"uri": "{out_path}?format=csv"}}"#);
388
- // let dmat2 = DMatrix::load(&read_path).unwrap();
399
+ let dmat2 = DMatrix :: load_binary ( out_path) . unwrap ( ) ;
389
400
390
- // assert_eq!(dmat.num_rows(), dmat2.num_rows());
391
- // assert_eq!(dmat.num_cols(), dmat2.num_cols());
401
+ assert_eq ! ( dmat. num_rows( ) , dmat2. num_rows( ) ) ;
402
+ assert_eq ! ( dmat. num_cols( ) , dmat2. num_cols( ) ) ;
392
403
// TODO: check contents as well, if possible
393
404
}
394
405
0 commit comments