@@ -2,14 +2,20 @@ use ndarray::ArrayView1;
22
33use rayon:: iter:: IndexedParallelIterator ;
44use rayon:: prelude:: * ;
5+ use std:: thread:: available_parallelism;
56
67use super :: types:: Num ;
78use num_traits:: { AsPrimitive , FromPrimitive } ;
89
910// ---------------------- Binary search ----------------------
1011
1112// #[inline(always)]
12- fn binary_search < T : PartialOrd > ( arr : ArrayView1 < T > , value : T , left : usize , right : usize ) -> usize {
13+ fn binary_search < T : Copy + PartialOrd > (
14+ arr : ArrayView1 < T > ,
15+ value : T ,
16+ left : usize ,
17+ right : usize ,
18+ ) -> usize {
1319 let mut size: usize = right - left;
1420 let mut left: usize = left;
1521 let mut right: usize = right;
@@ -27,7 +33,7 @@ fn binary_search<T: PartialOrd>(arr: ArrayView1<T>, value: T, left: usize, right
2733}
2834
2935// #[inline(always)]
30- fn binary_search_with_mid < T : PartialOrd > (
36+ fn binary_search_with_mid < T : Copy + PartialOrd > (
3137 arr : ArrayView1 < T > ,
3238 value : T ,
3339 left : usize ,
@@ -69,17 +75,17 @@ where
6975 ( arr[ arr. len ( ) - 1 ] . as_ ( ) / nb_bins as f64 ) - ( arr[ 0 ] . as_ ( ) / nb_bins as f64 ) ;
7076 let idx_step: usize = arr. len ( ) / nb_bins; // used to pre-guess the mid index
7177 let mut value: f64 = arr[ 0 ] . as_ ( ) ; // Search value
72- let mut idx = 0 ; // Index of the search value
78+ let mut idx: usize = 0 ; // Index of the search value
7379 ( 0 ..nb_bins) . map ( move |_| {
74- let start_idx = idx; // Start index of the bin (previous end index)
80+ let start_idx: usize = idx; // Start index of the bin (previous end index)
7581 value += val_step;
76- let mid = idx + idx_step;
82+ let mid: usize = idx + idx_step;
7783 let mid = if mid < arr. len ( ) - 1 {
7884 mid
7985 } else {
8086 arr. len ( ) - 2 // TODO: arr.len() - 1 gives error I thought...
8187 } ;
82- let search_value = T :: from_f64 ( value) . unwrap ( ) ;
88+ let search_value: T = T :: from_f64 ( value) . unwrap ( ) ;
8389 // Implementation WITHOUT pre-guessing mid is slower!!
8490 // idx = binary_search(arr, search_value, idx, arr.len()-1);
8591 idx = binary_search_with_mid ( arr, search_value, idx, arr. len ( ) - 1 , mid) ; // End index of the bin
@@ -102,7 +108,7 @@ fn sequential_add_mul(start_val: f64, add_val: f64, mul: usize) -> f64 {
102108pub ( crate ) fn get_equidistant_bin_idx_iterator_parallel < T > (
103109 arr : ArrayView1 < T > ,
104110 nb_bins : usize ,
105- ) -> impl IndexedParallelIterator < Item = ( usize , usize ) > + ' _
111+ ) -> impl IndexedParallelIterator < Item = impl Iterator < Item = ( usize , usize ) > + ' _ > + ' _
106112where
107113 T : Num + FromPrimitive + AsPrimitive < f64 > + Sync + Send ,
108114{
@@ -111,14 +117,35 @@ where
111117 let val_step: f64 =
112118 ( arr[ arr. len ( ) - 1 ] . as_ ( ) / nb_bins as f64 ) - ( arr[ 0 ] . as_ ( ) / nb_bins as f64 ) ;
113119 let arr0: f64 = arr[ 0 ] . as_ ( ) ;
114- ( 0 ..nb_bins) . into_par_iter ( ) . map ( move |i| {
115- let start_value = sequential_add_mul ( arr0, val_step, i) ;
116- let end_value = start_value + val_step;
117- let start_value = T :: from_f64 ( start_value) . unwrap ( ) ;
118- let end_value = T :: from_f64 ( end_value) . unwrap ( ) ;
119- let start_idx = binary_search ( arr, start_value, 0 , arr. len ( ) - 1 ) ;
120- let end_idx = binary_search ( arr, end_value, start_idx, arr. len ( ) - 1 ) ;
121- ( start_idx, end_idx)
120+ let nb_threads = available_parallelism ( ) . map ( |x| x. get ( ) ) . unwrap_or ( 1 ) ;
121+ let nb_threads = if nb_threads > nb_bins {
122+ nb_bins
123+ } else {
124+ nb_threads
125+ } ;
126+ let nb_bins_per_thread = nb_bins / nb_threads;
127+ let nb_bins_last_thread = nb_bins - nb_bins_per_thread * ( nb_threads - 1 ) ;
128+ // Iterate over the number of threads
129+ // -> for each thread perform the binary search sorted with moving left and
130+ // yield the indices (using the same idea as for the sequential version)
131+ ( 0 ..nb_threads) . into_par_iter ( ) . map ( move |i| {
132+ // Search the start of the fist bin o(f the thread)
133+ let mut value: f64 = sequential_add_mul ( arr0, val_step, i * nb_bins_per_thread) ; // Search value
134+ let start_value: T = T :: from_f64 ( value) . unwrap ( ) ;
135+ let mut idx: usize = binary_search ( arr, start_value, 0 , arr. len ( ) - 1 ) ; // Index of the search value
136+ let nb_bins_thread = if i == nb_threads - 1 {
137+ nb_bins_last_thread
138+ } else {
139+ nb_bins_per_thread
140+ } ;
141+ // Perform sequential binary search for the end of the bins (of the thread)
142+ ( 0 ..nb_bins_thread) . map ( move |_| {
143+ let start_idx: usize = idx; // Start index of the bin (previous end index)
144+ value += val_step;
145+ let search_value: T = T :: from_f64 ( value) . unwrap ( ) ;
146+ idx = binary_search ( arr, search_value, idx, arr. len ( ) - 1 ) ; // End index of the bin
147+ ( start_idx, idx)
148+ } )
122149 } )
123150}
124151
@@ -207,7 +234,10 @@ mod tests {
207234 let bin_idxs = bin_idxs_iter. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) ;
208235 assert_eq ! ( bin_idxs, vec![ 0 , 3 , 6 ] ) ;
209236 let bin_idxs_iter = get_equidistant_bin_idx_iterator_parallel ( arr. view ( ) , 3 ) ;
210- let bin_idxs = bin_idxs_iter. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) ;
237+ let bin_idxs = bin_idxs_iter
238+ . map ( |x| x. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) )
239+ . flatten ( )
240+ . collect :: < Vec < usize > > ( ) ;
211241 assert_eq ! ( bin_idxs, vec![ 0 , 3 , 6 ] ) ;
212242 }
213243
@@ -225,7 +255,10 @@ mod tests {
225255 let bin_idxs_iter = get_equidistant_bin_idx_iterator ( arr. view ( ) , nb_bins) ;
226256 let bin_idxs = bin_idxs_iter. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) ;
227257 let bin_idxs_iter = get_equidistant_bin_idx_iterator_parallel ( arr. view ( ) , nb_bins) ;
228- let bin_idxs_parallel = bin_idxs_iter. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) ;
258+ let bin_idxs_parallel = bin_idxs_iter
259+ . map ( |x| x. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) )
260+ . flatten ( )
261+ . collect :: < Vec < usize > > ( ) ;
229262 assert_eq ! ( bin_idxs, bin_idxs_parallel) ;
230263 }
231264 }
0 commit comments