Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/functions-nested/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ paste = "1.0.14"
[dev-dependencies]
criterion = { workspace = true, features = ["async_tokio"] }
rand = { workspace = true }
rstest = { workspace = true }

[[bench]]
harness = false
Expand Down
63 changes: 53 additions & 10 deletions datafusion/functions-nested/src/set_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::row::{RowConverter, SortField};
use datafusion_common::cast::{as_large_list_array, as_list_array};
use datafusion_common::utils::ListCoercion;
use datafusion_common::{
exec_err, internal_err, plan_err, utils::take_function_args, Result,
};
use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
Expand Down Expand Up @@ -289,13 +287,7 @@ impl ScalarUDFImpl for ArrayDistinct {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match &arg_types[0] {
List(field) => Ok(DataType::new_list(field.data_type().clone(), true)),
LargeList(field) => {
Ok(DataType::new_large_list(field.data_type().clone(), true))
}
arg_type => plan_err!("{} does not support type {arg_type}", self.name()),
}
Ok(arg_types[0].clone())
}

fn invoke_with_args(
Expand Down Expand Up @@ -563,3 +555,54 @@ fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
array.nulls().cloned(),
)?))
}

#[cfg(test)]
mod tests {
use rstest::*;
use std::sync::Arc;

use arrow::{
array::{Int32Array, ListArray},
buffer::OffsetBuffer,
datatypes::{DataType, Field},
};
use datafusion_common::{config::ConfigOptions, DataFusionError};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

use crate::set_ops::array_distinct_udf;

#[rstest(inner_nullable, case(true), case(false))]
#[test]
fn test_array_distinct_inner_nullability_result_type_match_return_type(
inner_nullable: bool,
) -> Result<(), DataFusionError> {
let udf = array_distinct_udf();

let inner_field = Field::new_list_field(DataType::Int32, inner_nullable);
let input_field = Field::new_list("input", Arc::new(inner_field.clone()), true);

// [[1, 1, 2]]
let input_array = ListArray::new(
inner_field.into(),
OffsetBuffer::new(vec![0, 3].into()),
Arc::new(Int32Array::new(vec![1, 1, 2].into(), None)),
None,
);
let input_array = ColumnarValue::Array(Arc::new(input_array));

let result = udf.invoke_with_args(ScalarFunctionArgs {
args: vec![input_array],
arg_fields: vec![input_field.clone().into()],
number_rows: 1,
return_field: input_field.clone().into(),
config_options: Arc::new(ConfigOptions::default()),
})?;

assert_eq!(
result.data_type(),
udf.return_type(&[input_field.data_type().clone()])?
);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`ScalarUDF::invoke_with_args` already performs a data type check when debug assertions are enabled:

#[cfg(debug_assertions)]
{
if &result.data_type() != return_field.data_type() {
return datafusion_common::internal_err!("Function '{}' returned value of type '{:?}' while the following type was promised at planning time and expected: '{:?}'",
self.name(),
result.data_type(),
return_field.data_type()
);
}
}

Even so, I think it is better to have an explicit assertion in the test code, so that the test does not rely on that debug-only check.


Ok(())
}
}