array_distinct: report the input list type from return_type

ArrayDistinct::return_type used to rebuild the List / LargeList type from the
element data type, passing `true` as the nullability of the rebuilt inner
field. It now returns a clone of the argument type unchanged, so the inner
field of the input type is carried through as-is. The `plan_err` import is
dropped because its only visible use was in the removed match arm.

A unit test is added that invokes the UDF on the single-row list [[1, 1, 2]]
with both a nullable and a non-nullable inner field, and asserts that the data
type of the result equals the type reported by return_type.

diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs
index 555767f8f070..53642bf1622b 100644
--- a/datafusion/functions-nested/src/set_ops.rs
+++ b/datafusion/functions-nested/src/set_ops.rs
@@ -29,9 +29,7 @@ use arrow::datatypes::{DataType, Field, FieldRef};
 use arrow::row::{RowConverter, SortField};
 use datafusion_common::cast::{as_large_list_array, as_list_array};
 use datafusion_common::utils::ListCoercion;
-use datafusion_common::{
-    exec_err, internal_err, plan_err, utils::take_function_args, Result,
-};
+use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result};
 use datafusion_expr::{
     ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
 };
@@ -289,13 +287,7 @@ impl ScalarUDFImpl for ArrayDistinct {
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        match &arg_types[0] {
-            List(field) => Ok(DataType::new_list(field.data_type().clone(), true)),
-            LargeList(field) => {
-                Ok(DataType::new_large_list(field.data_type().clone(), true))
-            }
-            arg_type => plan_err!("{} does not support type {arg_type}", self.name()),
-        }
+        Ok(arg_types[0].clone())
     }
 
     fn invoke_with_args(
@@ -563,3 +555,54 @@ fn general_array_distinct(
         array.nulls().cloned(),
     )?))
 }
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow::{
+        array::{Int32Array, ListArray},
+        buffer::OffsetBuffer,
+        datatypes::{DataType, Field},
+    };
+    use datafusion_common::{config::ConfigOptions, DataFusionError};
+    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
+
+    use crate::set_ops::array_distinct_udf;
+
+    #[test]
+    fn test_array_distinct_inner_nullability_result_type_match_return_type(
+    ) -> Result<(), DataFusionError> {
+        let udf = array_distinct_udf();
+
+        for inner_nullable in [true, false] {
+            let inner_field = Field::new_list_field(DataType::Int32, inner_nullable);
+            let input_field =
+                Field::new_list("input", Arc::new(inner_field.clone()), true);
+
+            // [[1, 1, 2]]
+            let input_array = ListArray::new(
+                inner_field.into(),
+                OffsetBuffer::new(vec![0, 3].into()),
+                Arc::new(Int32Array::new(vec![1, 1, 2].into(), None)),
+                None,
+            );
+
+            let input_array = ColumnarValue::Array(Arc::new(input_array));
+
+            let result = udf.invoke_with_args(ScalarFunctionArgs {
+                args: vec![input_array],
+                arg_fields: vec![input_field.clone().into()],
+                number_rows: 1,
+                return_field: input_field.clone().into(),
+                config_options: Arc::new(ConfigOptions::default()),
+            })?;
+
+            assert_eq!(
+                result.data_type(),
+                udf.return_type(&[input_field.data_type().clone()])?
+            );
+        }
+        Ok(())
+    }
+}