Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -954,13 +954,7 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let input_expr = acc_args
.exprs
.first()
.ok_or(exec_datafusion_err!("Expected one argument"))?;
let input_field = input_expr.return_field(acc_args.schema)?;

let double_output = input_field
let double_output = acc_args.expr_fields[0]
.metadata()
.get("modify_values")
.map(|v| v == "double_output")
Expand Down
9 changes: 9 additions & 0 deletions datafusion/ffi/src/udaf/accumulator_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
pub struct ForeignAccumulatorArgs {
pub return_field: FieldRef,
pub schema: Schema,
pub expr_fields: Vec<FieldRef>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this be a breaking FFI change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, since FFI_AccumulatorArgs seems to be the one marked as being stable across FFI boundaries:

/// A stable struct for sharing [`AccumulatorArgs`] across FFI boundaries.
/// For an explanation of each field, see the corresponding field
/// defined in [`AccumulatorArgs`].
///
/// Not every [`AccumulatorArgs`] field has a dedicated member here: there is
/// no `ignore_nulls`, `order_bys`, `is_distinct`, `exprs` or `expr_fields`
/// field in this struct.
// NOTE(review): the receiving side appears to decode the missing values from
// `physical_expr_def` and to re-derive `expr_fields` from the decoded
// expressions and `schema` — confirm against the `TryFrom` implementations.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_AccumulatorArgs {
// Counterpart of `AccumulatorArgs::return_field`.
// NOTE(review): presumably a single field wrapped as a schema — confirm.
return_field: WrappedSchema,
// Counterpart of `AccumulatorArgs::schema`.
schema: WrappedSchema,
// Counterpart of `AccumulatorArgs::is_reversed`.
is_reversed: bool,
// Counterpart of `AccumulatorArgs::name`, held as an owned `RString`.
name: RString,
// Byte payload.
// NOTE(review): presumably a serialized (protobuf) physical-expression
// definition that carries the remaining arguments — confirm.
physical_expr_def: RVec<u8>,
}

Though I am not familiar with the FFI-related code.

Copy link
Contributor

@alamb alamb Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI @timsaucer -- perhaps you can confirm this doesn't mess up the FFI API

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I've been slammed and didn't get to this sooner. Short answer: No, this is not breaking. Long answer: It probably should be.

The way this is set up it will continue to work but that is because we now have a duplicate piece of code here and here.

The FFI code is set up to try to mirror the non-ffi versions as much as possible. As long as one of these code paths linked above doesn't change then we should have no problem. It would probably be more robust to have those fields passed just like the other fields between the FFI versions.

We do have a unit test that does cover round tripping these arguments, so we are probably okay with this as is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having test coverage will certainly allow me to sleep better at night

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking this @timsaucer

Hopefully this existing test is sufficient for @alamb to have a nice sleep

/// Sends an [`AccumulatorArgs`] through its FFI representation and back,
/// then checks that the reconstructed value renders identically to the
/// original.
#[test]
fn test_round_trip_accumulator_args() -> Result<()> {
    let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    let original = AccumulatorArgs {
        return_field: Field::new("f", DataType::Float64, true).into(),
        schema: &input_schema,
        expr_fields: &[Field::new("a", DataType::Int32, true).into()],
        ignore_nulls: false,
        order_bys: &[PhysicalSortExpr::new_default(col("a", &input_schema)?)],
        is_reversed: false,
        name: "round_trip",
        is_distinct: true,
        exprs: &[col("a", &input_schema)?],
    };

    // Render the original before converting: the conversion takes it by value.
    let expected = format!("{original:?}");

    // AccumulatorArgs -> FFI_AccumulatorArgs -> ForeignAccumulatorArgs -> AccumulatorArgs
    let ffi_form = FFI_AccumulatorArgs::try_from(original)?;
    let foreign_form = ForeignAccumulatorArgs::try_from(ffi_form)?;
    let restored = AccumulatorArgs::from(&foreign_form);
    let actual = format!("{restored:?}");

    // AccumulatorArgs does not implement Eq, so the debug renderings are
    // compared instead.
    assert_eq!(expected, actual);
    Ok(())
}

pub ignore_nulls: bool,
pub order_bys: Vec<PhysicalSortExpr>,
pub is_reversed: bool,
Expand Down Expand Up @@ -132,9 +133,15 @@ impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs {

let exprs = parse_physical_exprs(&proto_def.expr, &task_ctx, &schema, &codex)?;

let expr_fields = exprs
.iter()
.map(|e| e.return_field(&schema))
.collect::<Result<Vec<_>, _>>()?;

Ok(Self {
return_field,
schema,
expr_fields,
ignore_nulls: proto_def.ignore_nulls,
order_bys,
is_reversed: value.is_reversed,
Expand All @@ -150,6 +157,7 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> {
Self {
return_field: Arc::clone(&value.return_field),
schema: &value.schema,
expr_fields: &value.expr_fields,
ignore_nulls: value.ignore_nulls,
order_bys: &value.order_bys,
is_reversed: value.is_reversed,
Expand All @@ -175,6 +183,7 @@ mod tests {
let orig_args = AccumulatorArgs {
return_field: Field::new("f", DataType::Float64, true).into(),
schema: &schema,
expr_fields: &[Field::new("a", DataType::Int32, true).into()],
ignore_nulls: false,
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
is_reversed: false,
Expand Down
2 changes: 2 additions & 0 deletions datafusion/ffi/src/udaf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ mod tests {
let acc_args = AccumulatorArgs {
return_field: Field::new("f", DataType::Float64, true).into(),
schema: &schema,
expr_fields: &[Field::new("a", DataType::Float64, true).into()],
ignore_nulls: true,
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
is_reversed: false,
Expand Down Expand Up @@ -782,6 +783,7 @@ mod tests {
let acc_args = AccumulatorArgs {
return_field: Field::new("f", DataType::Float64, true).into(),
schema: &schema,
expr_fields: &[Field::new("a", DataType::Float64, true).into()],
ignore_nulls: true,
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
is_reversed: false,
Expand Down
6 changes: 5 additions & 1 deletion datafusion/functions-aggregate-common/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ pub struct AccumulatorArgs<'a> {
/// The return field of the aggregate function.
pub return_field: FieldRef,

/// The schema of the input arguments
/// Input schema to the aggregate function. If you need to check the data type,
/// nullability, or metadata of the input arguments, use `expr_fields` below instead.
pub schema: &'a Schema,

/// Whether to ignore nulls.
Expand Down Expand Up @@ -67,6 +68,9 @@ pub struct AccumulatorArgs<'a> {

/// The physical expression of arguments the aggregate function takes.
pub exprs: &'a [Arc<dyn PhysicalExpr>],

/// Fields corresponding to each expr (same order & length).
pub expr_fields: &'a [FieldRef],
Comment on lines +72 to +73
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Main change is here

}

impl AccumulatorArgs<'_> {
Expand Down
8 changes: 6 additions & 2 deletions datafusion/functions-aggregate/benches/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion};

fn prepare_group_accumulator() -> Box<dyn GroupsAccumulator> {
let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)]));
let expr = col("f", &schema).unwrap();
let accumulator_args = AccumulatorArgs {
return_field: Field::new("f", DataType::Int64, true).into(),
schema: &schema,
expr_fields: &[expr.return_field(&schema).unwrap()],
ignore_nulls: false,
order_bys: &[],
is_reversed: false,
name: "COUNT(f)",
is_distinct: false,
exprs: &[col("f", &schema).unwrap()],
exprs: &[expr],
};
let count_fn = Count::new();

Expand All @@ -56,15 +58,17 @@ fn prepare_accumulator() -> Box<dyn Accumulator> {
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
true,
)]));
let expr = col("f", &schema).unwrap();
let accumulator_args = AccumulatorArgs {
return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
schema: &schema,
expr_fields: &[expr.return_field(&schema).unwrap()],
ignore_nulls: false,
order_bys: &[],
is_reversed: false,
name: "COUNT(f)",
is_distinct: true,
exprs: &[col("f", &schema).unwrap()],
exprs: &[expr],
};
let count_fn = Count::new();

Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/benches/min_max_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ fn create_max_bytes_accumulator() -> Box<dyn GroupsAccumulator> {
max.create_groups_accumulator(AccumulatorArgs {
return_field: Arc::new(Field::new("value", DataType::Utf8, true)),
schema: &input_schema,
expr_fields: &[Field::new("value", DataType::Utf8, true).into()],
ignore_nulls: true,
order_bys: &[],
is_reversed: false,
Expand Down
3 changes: 2 additions & 1 deletion datafusion/functions-aggregate/benches/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ fn prepare_accumulator(data_type: &DataType) -> Box<dyn GroupsAccumulator> {
let field = Field::new("f", data_type.clone(), true).into();
let schema = Arc::new(Schema::new(vec![Arc::clone(&field)]));
let accumulator_args = AccumulatorArgs {
return_field: field,
return_field: Arc::clone(&field),
schema: &schema,
expr_fields: &[field],
ignore_nulls: false,
order_bys: &[],
is_reversed: false,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ impl AggregateUDFImpl for ApproxDistinct {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let data_type = acc_args.expr_fields[0].data_type();

let accumulator: Box<dyn Accumulator> = match data_type {
// TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl AggregateUDFImpl for ApproxMedian {

Ok(Box::new(ApproxPercentileAccumulator::new(
0.5_f64,
acc_args.exprs[0].data_type(acc_args.schema)?,
acc_args.expr_fields[0].data_type().clone(),
)))
}

Expand Down
13 changes: 6 additions & 7 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ impl ApproxPercentileCont {
None
};

let data_type = args.exprs[0].data_type(args.schema)?;
let data_type = args.expr_fields[0].data_type();
let accumulator: ApproxPercentileAccumulator = match data_type {
t @ (DataType::UInt8
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
Expand All @@ -198,12 +198,11 @@ impl ApproxPercentileCont {
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64) => {
| DataType::Float64 => {
if let Some(max_size) = tdigest_max_size {
ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
}else{
ApproxPercentileAccumulator::new(percentile, t)

ApproxPercentileAccumulator::new_with_max_size(percentile, data_type.clone(), max_size)
} else {
ApproxPercentileAccumulator::new(percentile, data_type.clone())
}
}
other => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,28 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight {
Arc::clone(&acc_args.exprs[2]), // percentile
]
},
..acc_args
expr_fields: if acc_args.exprs.len() == 4 {
&[
Arc::clone(&acc_args.expr_fields[0]), // value
Arc::clone(&acc_args.expr_fields[2]), // percentile
Arc::clone(&acc_args.expr_fields[3]), // centroids
]
} else {
&[
Arc::clone(&acc_args.expr_fields[0]), // value
Arc::clone(&acc_args.expr_fields[2]), // percentile
]
},
// Unchanged below; we list each field explicitly so that, if more fields are
// ever added to AccumulatorArgs, it is easier to see whether changes are also
// needed here.
return_field: acc_args.return_field,
schema: acc_args.schema,
ignore_nulls: acc_args.ignore_nulls,
order_bys: acc_args.order_bys,
is_reversed: acc_args.is_reversed,
name: acc_args.name,
is_distinct: acc_args.is_distinct,
};
let approx_percentile_cont_accumulator =
self.approx_percentile_cont.create_accumulator(sub_args)?;
Expand Down
18 changes: 11 additions & 7 deletions datafusion/functions-aggregate/src/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ impl AggregateUDFImpl for ArrayAgg {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let ignore_nulls =
acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
let field = &acc_args.expr_fields[0];
let data_type = field.data_type();
let ignore_nulls = acc_args.ignore_nulls && field.is_nullable();

if acc_args.is_distinct {
// Limitation similar to Postgres. The aggregation function can only mix
Expand All @@ -191,15 +191,15 @@ impl AggregateUDFImpl for ArrayAgg {
}
};
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
&data_type,
data_type,
sort_option,
ignore_nulls,
)?));
}

let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
return Ok(Box::new(ArrayAggAccumulator::try_new(
&data_type,
data_type,
ignore_nulls,
)?));
};
Expand All @@ -210,7 +210,7 @@ impl AggregateUDFImpl for ArrayAgg {
.collect::<Result<Vec<_>>>()?;

OrderSensitiveArrayAggAccumulator::try_new(
&data_type,
data_type,
&ordering_dtypes,
ordering,
self.is_input_pre_ordered,
Expand Down Expand Up @@ -802,6 +802,7 @@ mod tests {
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::internal_err;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use std::sync::Arc;

Expand Down Expand Up @@ -1159,15 +1160,18 @@ mod tests {
}

fn build(&self) -> Result<Box<dyn Accumulator>> {
let expr = Arc::new(Column::new("col", 0));
let expr_field = expr.return_field(&self.schema)?;
ArrayAgg::default().accumulator(AccumulatorArgs {
return_field: Arc::clone(&self.return_field),
schema: &self.schema,
expr_fields: &[expr_field],
ignore_nulls: false,
order_bys: &self.order_bys,
is_reversed: false,
name: "",
is_distinct: self.distinct,
exprs: &[Arc::new(Column::new("col", 0))],
exprs: &[expr],
})
}

Expand Down
27 changes: 14 additions & 13 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@ impl AggregateUDFImpl for Avg {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let data_type = acc_args.expr_fields[0].data_type();
use DataType::*;

// instantiate specialized accumulator based on the type
if acc_args.is_distinct {
match (&data_type, acc_args.return_type()) {
match (data_type, acc_args.return_type()) {
// Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation
(Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),

Expand Down Expand Up @@ -362,12 +362,13 @@ impl AggregateUDFImpl for Avg {
) -> Result<Box<dyn GroupsAccumulator>> {
use DataType::*;

let data_type = args.exprs[0].data_type(args.schema)?;
let data_type = args.expr_fields[0].data_type();

// instantiate specialized accumulator based on the type
match (&data_type, args.return_field.data_type()) {
match (data_type, args.return_field.data_type()) {
(Float64, Float64) => {
Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
&data_type,
data_type,
args.return_field.data_type(),
|sum: f64, count: u64| Ok(sum / count as f64),
)))
Expand All @@ -386,7 +387,7 @@ impl AggregateUDFImpl for Avg {
move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);

Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
&data_type,
data_type,
args.return_field.data_type(),
avg_fn,
)))
Expand All @@ -405,7 +406,7 @@ impl AggregateUDFImpl for Avg {
move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);

Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
&data_type,
data_type,
args.return_field.data_type(),
avg_fn,
)))
Expand All @@ -424,7 +425,7 @@ impl AggregateUDFImpl for Avg {
move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);

Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
&data_type,
data_type,
args.return_field.data_type(),
avg_fn,
)))
Expand All @@ -445,7 +446,7 @@ impl AggregateUDFImpl for Avg {
};

Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
&data_type,
data_type,
args.return_field.data_type(),
avg_fn,
)))
Expand All @@ -459,31 +460,31 @@ impl AggregateUDFImpl for Avg {
DurationSecondType,
_,
>::new(
&data_type,
data_type,
args.return_type(),
avg_fn,
))),
TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
DurationMillisecondType,
_,
>::new(
&data_type,
data_type,
args.return_type(),
avg_fn,
))),
TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
DurationMicrosecondType,
_,
>::new(
&data_type,
data_type,
args.return_type(),
avg_fn,
))),
TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
DurationNanosecondType,
_,
>::new(
&data_type,
data_type,
args.return_type(),
avg_fn,
))),
Expand Down
Loading