Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -954,13 +954,7 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let input_expr = acc_args
.exprs
.first()
.ok_or(exec_datafusion_err!("Expected one argument"))?;
let input_field = input_expr.return_field(acc_args.schema)?;

let double_output = input_field
let double_output = acc_args.expr_fields[0]
.metadata()
.get("modify_values")
.map(|v| v == "double_output")
Expand Down
9 changes: 9 additions & 0 deletions datafusion/ffi/src/udaf/accumulator_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
pub struct ForeignAccumulatorArgs {
pub return_field: FieldRef,
pub schema: Schema,
pub expr_fields: Vec<FieldRef>,
pub ignore_nulls: bool,
pub order_bys: Vec<PhysicalSortExpr>,
pub is_reversed: bool,
Expand Down Expand Up @@ -132,9 +133,15 @@ impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs {

let exprs = parse_physical_exprs(&proto_def.expr, &task_ctx, &schema, &codex)?;

let expr_fields = exprs
.iter()
.map(|e| e.return_field(&schema))
.collect::<Result<Vec<_>, _>>()?;

Ok(Self {
return_field,
schema,
expr_fields,
ignore_nulls: proto_def.ignore_nulls,
order_bys,
is_reversed: value.is_reversed,
Expand All @@ -150,6 +157,7 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> {
Self {
return_field: Arc::clone(&value.return_field),
schema: &value.schema,
expr_fields: &value.expr_fields,
ignore_nulls: value.ignore_nulls,
order_bys: &value.order_bys,
is_reversed: value.is_reversed,
Expand All @@ -175,6 +183,7 @@ mod tests {
let orig_args = AccumulatorArgs {
return_field: Field::new("f", DataType::Float64, true).into(),
schema: &schema,
expr_fields: &[Field::new("a", DataType::Int32, true).into()],
ignore_nulls: false,
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
is_reversed: false,
Expand Down
2 changes: 2 additions & 0 deletions datafusion/ffi/src/udaf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ mod tests {
let acc_args = AccumulatorArgs {
return_field: Field::new("f", DataType::Float64, true).into(),
schema: &schema,
expr_fields: &[Field::new("a", DataType::Float64, true).into()],
ignore_nulls: true,
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
is_reversed: false,
Expand Down Expand Up @@ -782,6 +783,7 @@ mod tests {
let acc_args = AccumulatorArgs {
return_field: Field::new("f", DataType::Float64, true).into(),
schema: &schema,
expr_fields: &[Field::new("a", DataType::Float64, true).into()],
ignore_nulls: true,
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
is_reversed: false,
Expand Down
6 changes: 5 additions & 1 deletion datafusion/functions-aggregate-common/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ pub struct AccumulatorArgs<'a> {
/// The return field of the aggregate function.
pub return_field: FieldRef,

/// The schema of the input arguments
/// Input schema to the aggregate function. If you need to check data type, nullability
/// or metadata of input arguments then you should use `expr_fields` below instead.
pub schema: &'a Schema,

/// Whether to ignore nulls.
Expand Down Expand Up @@ -67,6 +68,9 @@ pub struct AccumulatorArgs<'a> {

/// The physical expression of arguments the aggregate function takes.
pub exprs: &'a [Arc<dyn PhysicalExpr>],

/// Fields corresponding to each expr (same order & length).
pub expr_fields: &'a [FieldRef],
Comment on lines +72 to +73
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Main change is here

}

impl AccumulatorArgs<'_> {
Expand Down
8 changes: 6 additions & 2 deletions datafusion/functions-aggregate/benches/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion};

fn prepare_group_accumulator() -> Box<dyn GroupsAccumulator> {
let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)]));
let expr = col("f", &schema).unwrap();
let accumulator_args = AccumulatorArgs {
return_field: Field::new("f", DataType::Int64, true).into(),
schema: &schema,
expr_fields: &[expr.return_field(&schema).unwrap()],
ignore_nulls: false,
order_bys: &[],
is_reversed: false,
name: "COUNT(f)",
is_distinct: false,
exprs: &[col("f", &schema).unwrap()],
exprs: &[expr],
};
let count_fn = Count::new();

Expand All @@ -56,15 +58,17 @@ fn prepare_accumulator() -> Box<dyn Accumulator> {
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
true,
)]));
let expr = col("f", &schema).unwrap();
let accumulator_args = AccumulatorArgs {
return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
schema: &schema,
expr_fields: &[expr.return_field(&schema).unwrap()],
ignore_nulls: false,
order_bys: &[],
is_reversed: false,
name: "COUNT(f)",
is_distinct: true,
exprs: &[col("f", &schema).unwrap()],
exprs: &[expr],
};
let count_fn = Count::new();

Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/benches/min_max_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ fn create_max_bytes_accumulator() -> Box<dyn GroupsAccumulator> {
max.create_groups_accumulator(AccumulatorArgs {
return_field: Arc::new(Field::new("value", DataType::Utf8, true)),
schema: &input_schema,
expr_fields: &[Field::new("value", DataType::Utf8, true).into()],
ignore_nulls: true,
order_bys: &[],
is_reversed: false,
Expand Down
3 changes: 2 additions & 1 deletion datafusion/functions-aggregate/benches/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ fn prepare_accumulator(data_type: &DataType) -> Box<dyn GroupsAccumulator> {
let field = Field::new("f", data_type.clone(), true).into();
let schema = Arc::new(Schema::new(vec![Arc::clone(&field)]));
let accumulator_args = AccumulatorArgs {
return_field: field,
return_field: Arc::clone(&field),
schema: &schema,
expr_fields: &[field],
ignore_nulls: false,
order_bys: &[],
is_reversed: false,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ impl AggregateUDFImpl for ApproxDistinct {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let data_type = acc_args.expr_fields[0].data_type();

let accumulator: Box<dyn Accumulator> = match data_type {
// TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl AggregateUDFImpl for ApproxMedian {

Ok(Box::new(ApproxPercentileAccumulator::new(
0.5_f64,
acc_args.exprs[0].data_type(acc_args.schema)?,
acc_args.expr_fields[0].data_type().clone(),
)))
}

Expand Down
13 changes: 6 additions & 7 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ impl ApproxPercentileCont {
None
};

let data_type = args.exprs[0].data_type(args.schema)?;
let data_type = args.expr_fields[0].data_type();
let accumulator: ApproxPercentileAccumulator = match data_type {
t @ (DataType::UInt8
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
Expand All @@ -198,12 +198,11 @@ impl ApproxPercentileCont {
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64) => {
| DataType::Float64 => {
if let Some(max_size) = tdigest_max_size {
ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
}else{
ApproxPercentileAccumulator::new(percentile, t)

ApproxPercentileAccumulator::new_with_max_size(percentile, data_type.clone(), max_size)
} else {
ApproxPercentileAccumulator::new(percentile, data_type.clone())
}
}
other => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,28 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight {
Arc::clone(&acc_args.exprs[2]), // percentile
]
},
..acc_args
expr_fields: if acc_args.exprs.len() == 4 {
&[
Arc::clone(&acc_args.expr_fields[0]), // value
Arc::clone(&acc_args.expr_fields[2]), // percentile
Arc::clone(&acc_args.expr_fields[3]), // centroids
]
} else {
&[
Arc::clone(&acc_args.expr_fields[0]), // value
Arc::clone(&acc_args.expr_fields[2]), // percentile
]
},
// Unchanged below; we list each field explicitly in case we ever add more
// fields to AccumulatorArgs making it easier to see if changes are also
// needed here.
return_field: acc_args.return_field,
schema: acc_args.schema,
ignore_nulls: acc_args.ignore_nulls,
order_bys: acc_args.order_bys,
is_reversed: acc_args.is_reversed,
name: acc_args.name,
is_distinct: acc_args.is_distinct,
};
let approx_percentile_cont_accumulator =
self.approx_percentile_cont.create_accumulator(sub_args)?;
Expand Down
18 changes: 11 additions & 7 deletions datafusion/functions-aggregate/src/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ impl AggregateUDFImpl for ArrayAgg {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let ignore_nulls =
acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
let field = &acc_args.expr_fields[0];
let data_type = field.data_type();
let ignore_nulls = acc_args.ignore_nulls && field.is_nullable();

if acc_args.is_distinct {
// Limitation similar to Postgres. The aggregation function can only mix
Expand All @@ -191,15 +191,15 @@ impl AggregateUDFImpl for ArrayAgg {
}
};
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
&data_type,
data_type,
sort_option,
ignore_nulls,
)?));
}

let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
return Ok(Box::new(ArrayAggAccumulator::try_new(
&data_type,
data_type,
ignore_nulls,
)?));
};
Expand All @@ -210,7 +210,7 @@ impl AggregateUDFImpl for ArrayAgg {
.collect::<Result<Vec<_>>>()?;

OrderSensitiveArrayAggAccumulator::try_new(
&data_type,
data_type,
&ordering_dtypes,
ordering,
self.is_input_pre_ordered,
Expand Down Expand Up @@ -802,6 +802,7 @@ mod tests {
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::internal_err;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use std::sync::Arc;

Expand Down Expand Up @@ -1159,15 +1160,18 @@ mod tests {
}

fn build(&self) -> Result<Box<dyn Accumulator>> {
let expr = Arc::new(Column::new("col", 0));
let expr_field = expr.return_field(&self.schema)?;
ArrayAgg::default().accumulator(AccumulatorArgs {
return_field: Arc::clone(&self.return_field),
schema: &self.schema,
expr_fields: &[expr_field],
ignore_nulls: false,
order_bys: &self.order_bys,
is_reversed: false,
name: "",
is_distinct: self.distinct,
exprs: &[Arc::new(Column::new("col", 0))],
exprs: &[expr],
})
}

Expand Down
27 changes: 14 additions & 13 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@ impl AggregateUDFImpl for Avg {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let data_type = acc_args.expr_fields[0].data_type();
use DataType::*;

// instantiate specialized accumulator based for the type
if acc_args.is_distinct {
match (&data_type, acc_args.return_type()) {
match (data_type, acc_args.return_type()) {
// Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation
(Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),

Expand Down Expand Up @@ -362,12 +362,13 @@ impl AggregateUDFImpl for Avg {
) -> Result<Box<dyn GroupsAccumulator>> {
use DataType::*;

let data_type = args.exprs[0].data_type(args.schema)?;
let data_type = args.expr_fields[0].data_type();

// instantiate specialized accumulator based for the type
match (&data_type, args.return_field.data_type()) {
match (data_type, args.return_field.data_type()) {
(Float64, Float64) => {
Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
&data_type,
data_type,
args.return_field.data_type(),
|sum: f64, count: u64| Ok(sum / count as f64),
)))
Expand All @@ -386,7 +387,7 @@ impl AggregateUDFImpl for Avg {
move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);

Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
&data_type,
data_type,
args.return_field.data_type(),
avg_fn,
)))
Expand All @@ -405,7 +406,7 @@ impl AggregateUDFImpl for Avg {
move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);

Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
&data_type,
data_type,
args.return_field.data_type(),
avg_fn,
)))
Expand All @@ -424,7 +425,7 @@ impl AggregateUDFImpl for Avg {
move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);

Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
&data_type,
data_type,
args.return_field.data_type(),
avg_fn,
)))
Expand All @@ -445,7 +446,7 @@ impl AggregateUDFImpl for Avg {
};

Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
&data_type,
data_type,
args.return_field.data_type(),
avg_fn,
)))
Expand All @@ -459,31 +460,31 @@ impl AggregateUDFImpl for Avg {
DurationSecondType,
_,
>::new(
&data_type,
data_type,
args.return_type(),
avg_fn,
))),
TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
DurationMillisecondType,
_,
>::new(
&data_type,
data_type,
args.return_type(),
avg_fn,
))),
TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
DurationMicrosecondType,
_,
>::new(
&data_type,
data_type,
args.return_type(),
avg_fn,
))),
TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
DurationNanosecondType,
_,
>::new(
&data_type,
data_type,
args.return_type(),
avg_fn,
))),
Expand Down
Loading