1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: array:: {
19- Array , ArrayRef , AsArray , BooleanArray , Int64Array , ListBuilder , PrimitiveBuilder ,
20- } ;
21- use arrow:: datatypes:: ArrowPrimitiveType ;
18+ use arrow:: array:: { ArrayRef , AsArray , BooleanArray , Int64Array , ListArray , PrimitiveArray } ;
19+ use arrow:: buffer:: OffsetBuffer ;
20+ use arrow:: datatypes:: { ArrowPrimitiveType , Field } ;
2221use datafusion_common:: HashSet ;
2322use datafusion_common:: hash_utils:: RandomState ;
2423use datafusion_expr_common:: groups_accumulator:: { EmitTo , GroupsAccumulator } ;
2524use std:: hash:: Hash ;
2625use std:: mem:: size_of;
2726use std:: sync:: Arc ;
2827
28+ use crate :: aggregate:: groups_accumulator:: accumulate:: accumulate;
29+
2930pub struct PrimitiveDistinctCountGroupsAccumulator < T : ArrowPrimitiveType >
3031where
3132 T :: Native : Eq + Hash ,
3233{
33- /// Count distinct per group.
34- values : Vec < HashSet < T :: Native , RandomState > > ,
34+ seen : HashSet < ( usize , T :: Native ) , RandomState > ,
35+ num_groups : usize ,
3536}
3637
37- impl < T : ArrowPrimitiveType > Default for PrimitiveDistinctCountGroupsAccumulator < T >
38+ impl < T : ArrowPrimitiveType > PrimitiveDistinctCountGroupsAccumulator < T >
3839where
3940 T :: Native : Eq + Hash ,
4041{
41- fn default ( ) -> Self {
42- Self :: new ( )
42+ pub fn new ( ) -> Self {
43+ Self {
44+ seen : HashSet :: default ( ) ,
45+ num_groups : 0 ,
46+ }
47+ }
48+
49+ fn emit_to_values ( & mut self , emit_to : EmitTo ) -> Vec < Vec < T :: Native > > {
50+ let num_emitted = match emit_to {
51+ EmitTo :: All => self . num_groups ,
52+ EmitTo :: First ( n) => n,
53+ } ;
54+
55+ let mut group_values: Vec < Vec < T :: Native > > = vec ! [ Vec :: new( ) ; num_emitted] ;
56+ let mut remaining = HashSet :: default ( ) ;
57+
58+ for ( group_idx, value) in self . seen . drain ( ) {
59+ if group_idx < num_emitted {
60+ group_values[ group_idx] . push ( value) ;
61+ } else {
62+ remaining. insert ( ( group_idx - num_emitted, value) ) ;
63+ }
64+ }
65+
66+ self . seen = remaining;
67+ match emit_to {
68+ EmitTo :: All => self . num_groups = 0 ,
69+ EmitTo :: First ( n) => self . num_groups = self . num_groups . saturating_sub ( n) ,
70+ }
71+
72+ group_values
4373 }
4474}
4575
46- impl < T : ArrowPrimitiveType > PrimitiveDistinctCountGroupsAccumulator < T >
76+ impl < T : ArrowPrimitiveType > Default for PrimitiveDistinctCountGroupsAccumulator < T >
4777where
4878 T :: Native : Eq + Hash ,
4979{
50- pub fn new ( ) -> Self {
51- Self { values : Vec :: new ( ) }
80+ fn default ( ) -> Self {
81+ Self :: new ( )
5282 }
5383}
5484
@@ -64,47 +94,40 @@ where
6494 opt_filter : Option < & BooleanArray > ,
6595 total_num_groups : usize ,
6696 ) -> datafusion_common:: Result < ( ) > {
67- self . values . resize_with ( total_num_groups, HashSet :: default) ;
68- debug_assert_eq ! ( values. len( ) , 1 , "multiple arguments are not supported" ) ;
69-
97+ debug_assert_eq ! ( values. len( ) , 1 ) ;
98+ self . num_groups = self . num_groups . max ( total_num_groups) ;
7099 let arr = values[ 0 ] . as_primitive :: < T > ( ) ;
71-
72- for ( idx, group_idx) in group_indices. iter ( ) . enumerate ( ) {
73- if let Some ( filter) = opt_filter
74- && !filter. value ( idx)
75- {
76- continue ;
77- }
78- if arr. is_valid ( idx) {
79- let value = arr. value ( idx) ;
80- self . values [ * group_idx] . insert ( value) ;
81- }
82- }
83-
100+ accumulate ( group_indices, arr, opt_filter, |group_idx, value| {
101+ self . seen . insert ( ( group_idx, value) ) ;
102+ } ) ;
84103 Ok ( ( ) )
85104 }
86105
87106 fn evaluate ( & mut self , emit_to : EmitTo ) -> datafusion_common:: Result < ArrayRef > {
88- let counts: Vec < i64 > = emit_to
89- . take_needed ( & mut self . values )
90- . iter ( )
91- . map ( |groups| groups. len ( ) as i64 )
92- . collect ( ) ;
93-
107+ let group_values = self . emit_to_values ( emit_to) ;
108+ let counts: Vec < i64 > = group_values. iter ( ) . map ( |v| v. len ( ) as i64 ) . collect ( ) ;
94109 Ok ( Arc :: new ( Int64Array :: from ( counts) ) )
95110 }
96111
97112 fn state ( & mut self , emit_to : EmitTo ) -> datafusion_common:: Result < Vec < ArrayRef > > {
98- let hash_sets = emit_to. take_needed ( & mut self . values ) ;
99- let mut builder = ListBuilder :: new ( PrimitiveBuilder :: < T > :: new ( ) ) ;
113+ let group_values = self . emit_to_values ( emit_to) ;
100114
101- for set in hash_sets {
102- for value in set {
103- builder . values ( ) . append_value ( value ) ;
104- }
105- builder . append ( true ) ;
115+ let mut offsets = vec ! [ 0i32 ] ;
116+ let mut all_values = Vec :: new ( ) ;
117+ for values in & group_values {
118+ all_values . extend ( values . iter ( ) . copied ( ) ) ;
119+ offsets . push ( all_values . len ( ) as i32 ) ;
106120 }
107- Ok ( vec ! [ Arc :: new( builder. finish( ) ) ] )
121+
122+ let values_array = Arc :: new ( PrimitiveArray :: < T > :: from_iter_values ( all_values) ) ;
123+ let list_array = ListArray :: new (
124+ Arc :: new ( Field :: new_list_field ( T :: DATA_TYPE , true ) ) ,
125+ OffsetBuffer :: new ( offsets. into ( ) ) ,
126+ values_array,
127+ None ,
128+ ) ;
129+
130+ Ok ( vec ! [ Arc :: new( list_array) ] )
108131 }
109132
110133 fn merge_batch (
@@ -114,26 +137,23 @@ where
114137 _opt_filter : Option < & BooleanArray > ,
115138 total_num_groups : usize ,
116139 ) -> datafusion_common:: Result < ( ) > {
117- self . values . resize_with ( total_num_groups, HashSet :: default) ;
140+ debug_assert_eq ! ( values. len( ) , 1 ) ;
141+ self . num_groups = self . num_groups . max ( total_num_groups) ;
118142 let list_array = values[ 0 ] . as_list :: < i32 > ( ) ;
119143
120144 for ( row_idx, group_idx) in group_indices. iter ( ) . enumerate ( ) {
121145 let inner = list_array. value ( row_idx) ;
122- let inner_set = inner. as_primitive :: < T > ( ) ;
123- for i in 0 ..inner . len ( ) {
124- self . values [ * group_idx ] . insert ( inner_set . value ( i ) ) ;
146+ let inner_arr = inner. as_primitive :: < T > ( ) ;
147+ for value in inner_arr . values ( ) . iter ( ) {
148+ self . seen . insert ( ( * group_idx , * value ) ) ;
125149 }
126150 }
151+
127152 Ok ( ( ) )
128153 }
129154
130155 fn size ( & self ) -> usize {
131156 size_of :: < Self > ( )
132- + self . values . capacity ( ) * size_of :: < HashSet < T :: Native , RandomState > > ( )
133- + self
134- . values
135- . iter ( )
136- . map ( |s| s. capacity ( ) * size_of :: < T :: Native > ( ) )
137- . sum :: < usize > ( )
157+ + self . seen . capacity ( ) * ( size_of :: < ( usize , T :: Native ) > ( ) + size_of :: < u64 > ( ) )
138158 }
139159}
0 commit comments