Skip to content

Commit 5692893

Browse files
committed
implement_group_accumulators_count_distinct_use_hashtable
1 parent f0b2a4a commit 5692893

File tree

1 file changed

+72
-52
lines changed
  • datafusion/functions-aggregate-common/src/aggregate/count_distinct

1 file changed

+72
-52
lines changed

datafusion/functions-aggregate-common/src/aggregate/count_distinct/groups.rs

Lines changed: 72 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -15,40 +15,70 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{
19-
Array, ArrayRef, AsArray, BooleanArray, Int64Array, ListBuilder, PrimitiveBuilder,
20-
};
21-
use arrow::datatypes::ArrowPrimitiveType;
18+
use arrow::array::{ArrayRef, AsArray, BooleanArray, Int64Array, ListArray, PrimitiveArray};
19+
use arrow::buffer::OffsetBuffer;
20+
use arrow::datatypes::{ArrowPrimitiveType, Field};
2221
use datafusion_common::HashSet;
2322
use datafusion_common::hash_utils::RandomState;
2423
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
2524
use std::hash::Hash;
2625
use std::mem::size_of;
2726
use std::sync::Arc;
2827

28+
use crate::aggregate::groups_accumulator::accumulate::accumulate;
29+
2930
pub struct PrimitiveDistinctCountGroupsAccumulator<T: ArrowPrimitiveType>
3031
where
3132
T::Native: Eq + Hash,
3233
{
33-
/// Count distinct per group.
34-
values: Vec<HashSet<T::Native, RandomState>>,
34+
seen: HashSet<(usize, T::Native), RandomState>,
35+
num_groups: usize,
3536
}
3637

37-
impl<T: ArrowPrimitiveType> Default for PrimitiveDistinctCountGroupsAccumulator<T>
38+
impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
3839
where
3940
T::Native: Eq + Hash,
4041
{
41-
fn default() -> Self {
42-
Self::new()
42+
pub fn new() -> Self {
43+
Self {
44+
seen: HashSet::default(),
45+
num_groups: 0,
46+
}
47+
}
48+
49+
fn emit_to_values(&mut self, emit_to: EmitTo) -> Vec<Vec<T::Native>> {
50+
let num_emitted = match emit_to {
51+
EmitTo::All => self.num_groups,
52+
EmitTo::First(n) => n,
53+
};
54+
55+
let mut group_values: Vec<Vec<T::Native>> = vec![Vec::new(); num_emitted];
56+
let mut remaining = HashSet::default();
57+
58+
for (group_idx, value) in self.seen.drain() {
59+
if group_idx < num_emitted {
60+
group_values[group_idx].push(value);
61+
} else {
62+
remaining.insert((group_idx - num_emitted, value));
63+
}
64+
}
65+
66+
self.seen = remaining;
67+
match emit_to {
68+
EmitTo::All => self.num_groups = 0,
69+
EmitTo::First(n) => self.num_groups = self.num_groups.saturating_sub(n),
70+
}
71+
72+
group_values
4373
}
4474
}
4575

46-
impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
76+
impl<T: ArrowPrimitiveType> Default for PrimitiveDistinctCountGroupsAccumulator<T>
4777
where
4878
T::Native: Eq + Hash,
4979
{
50-
pub fn new() -> Self {
51-
Self { values: Vec::new() }
80+
fn default() -> Self {
81+
Self::new()
5282
}
5383
}
5484

@@ -64,47 +94,40 @@ where
6494
opt_filter: Option<&BooleanArray>,
6595
total_num_groups: usize,
6696
) -> datafusion_common::Result<()> {
67-
self.values.resize_with(total_num_groups, HashSet::default);
68-
debug_assert_eq!(values.len(), 1, "multiple arguments are not supported");
69-
97+
debug_assert_eq!(values.len(), 1);
98+
self.num_groups = self.num_groups.max(total_num_groups);
7099
let arr = values[0].as_primitive::<T>();
71-
72-
for (idx, group_idx) in group_indices.iter().enumerate() {
73-
if let Some(filter) = opt_filter
74-
&& !filter.value(idx)
75-
{
76-
continue;
77-
}
78-
if arr.is_valid(idx) {
79-
let value = arr.value(idx);
80-
self.values[*group_idx].insert(value);
81-
}
82-
}
83-
100+
accumulate(group_indices, arr, opt_filter, |group_idx, value| {
101+
self.seen.insert((group_idx, value));
102+
});
84103
Ok(())
85104
}
86105

87106
fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
88-
let counts: Vec<i64> = emit_to
89-
.take_needed(&mut self.values)
90-
.iter()
91-
.map(|groups| groups.len() as i64)
92-
.collect();
93-
107+
let group_values = self.emit_to_values(emit_to);
108+
let counts: Vec<i64> = group_values.iter().map(|v| v.len() as i64).collect();
94109
Ok(Arc::new(Int64Array::from(counts)))
95110
}
96111

97112
fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
98-
let hash_sets = emit_to.take_needed(&mut self.values);
99-
let mut builder = ListBuilder::new(PrimitiveBuilder::<T>::new());
113+
let group_values = self.emit_to_values(emit_to);
100114

101-
for set in hash_sets {
102-
for value in set {
103-
builder.values().append_value(value);
104-
}
105-
builder.append(true);
115+
let mut offsets = vec![0i32];
116+
let mut all_values = Vec::new();
117+
for values in &group_values {
118+
all_values.extend(values.iter().copied());
119+
offsets.push(all_values.len() as i32);
106120
}
107-
Ok(vec![Arc::new(builder.finish())])
121+
122+
let values_array = Arc::new(PrimitiveArray::<T>::from_iter_values(all_values));
123+
let list_array = ListArray::new(
124+
Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
125+
OffsetBuffer::new(offsets.into()),
126+
values_array,
127+
None,
128+
);
129+
130+
Ok(vec![Arc::new(list_array)])
108131
}
109132

110133
fn merge_batch(
@@ -114,26 +137,23 @@ where
114137
_opt_filter: Option<&BooleanArray>,
115138
total_num_groups: usize,
116139
) -> datafusion_common::Result<()> {
117-
self.values.resize_with(total_num_groups, HashSet::default);
140+
debug_assert_eq!(values.len(), 1);
141+
self.num_groups = self.num_groups.max(total_num_groups);
118142
let list_array = values[0].as_list::<i32>();
119143

120144
for (row_idx, group_idx) in group_indices.iter().enumerate() {
121145
let inner = list_array.value(row_idx);
122-
let inner_set = inner.as_primitive::<T>();
123-
for i in 0..inner.len() {
124-
self.values[*group_idx].insert(inner_set.value(i));
146+
let inner_arr = inner.as_primitive::<T>();
147+
for value in inner_arr.values().iter() {
148+
self.seen.insert((*group_idx, *value));
125149
}
126150
}
151+
127152
Ok(())
128153
}
129154

130155
fn size(&self) -> usize {
131156
size_of::<Self>()
132-
+ self.values.capacity() * size_of::<HashSet<T::Native, RandomState>>()
133-
+ self
134-
.values
135-
.iter()
136-
.map(|s| s.capacity() * size_of::<T::Native>())
137-
.sum::<usize>()
157+
+ self.seen.capacity() * (size_of::<(usize, T::Native)>() + size_of::<u64>())
138158
}
139159
}

0 commit comments

Comments
 (0)