diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs
new file mode 100644
index 000000000000..0dcc58d5bb8e
--- /dev/null
+++ b/datafusion/spark/src/function/string/concat.rs
@@ -0,0 +1,269 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::Array;
+use arrow::buffer::NullBuffer;
+use arrow::datatypes::DataType;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
+    Volatility,
+};
+use datafusion_functions::string::concat::ConcatFunc;
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark-compatible `concat` expression
+/// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
+///
+/// Concatenates multiple input strings into a single string.
+/// Returns NULL if any input is NULL.
+///
+/// Differences with DataFusion concat:
+/// - Support 0 arguments
+/// - Return NULL if any input is NULL
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkConcat {
+    signature: Signature,
+}
+
+impl Default for SparkConcat {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkConcat {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::one_of(
+                vec![TypeSignature::UserDefined, TypeSignature::Nullary],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkConcat {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "concat"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        spark_concat(args)
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        // Accept any string types, including zero arguments
+        Ok(arg_types.to_vec())
+    }
+}
+
+/// Represents the null state for Spark concat
+enum NullMaskResolution {
+    /// Return NULL as the result (e.g., scalar inputs with at least one NULL)
+    ReturnNull,
+    /// No null mask needed (e.g., all scalar inputs are non-NULL)
+    NoMask,
+    /// Null mask to apply for arrays
+    Apply(NullBuffer),
+}
+
+/// Concatenates strings, returning NULL if any input is NULL
+/// This is a Spark-specific wrapper around DataFusion's concat that returns NULL
+/// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs.
+fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+    let ScalarFunctionArgs {
+        args: arg_values,
+        arg_fields,
+        number_rows,
+        return_field,
+        config_options,
+    } = args;
+
+    // Handle zero-argument case: return empty string
+    if arg_values.is_empty() {
+        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
+            Some(String::new()),
+        )));
+    }
+
+    // Step 1: Check for NULL mask in incoming args
+    let null_mask = compute_null_mask(&arg_values, number_rows)?;
+
+    // If all scalars and any is NULL, return NULL immediately
+    if matches!(null_mask, NullMaskResolution::ReturnNull) {
+        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+    }
+
+    // Step 2: Delegate to DataFusion's concat
+    let concat_func = ConcatFunc::new();
+    let func_args = ScalarFunctionArgs {
+        args: arg_values,
+        arg_fields,
+        number_rows,
+        return_field,
+        config_options,
+    };
+    let result = concat_func.invoke_with_args(func_args)?;
+
+    // Step 3: Apply NULL mask to result
+    apply_null_mask(result, null_mask)
+}
+
+/// Compute NULL mask for the arguments using NullBuffer::union
+fn compute_null_mask(
+    args: &[ColumnarValue],
+    number_rows: usize,
+) -> Result<NullMaskResolution> {
+    // Check if all arguments are scalars
+    let all_scalars = args
+        .iter()
+        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
+
+    if all_scalars {
+        // For scalars, check if any is NULL
+        for arg in args {
+            if let ColumnarValue::Scalar(scalar) = arg {
+                if scalar.is_null() {
+                    return Ok(NullMaskResolution::ReturnNull);
+                }
+            }
+        }
+        // No NULLs in scalars
+        Ok(NullMaskResolution::NoMask)
+    } else {
+        // For arrays, compute NULL mask for each row using NullBuffer::union
+        let array_len = args
+            .iter()
+            .find_map(|arg| match arg {
+                ColumnarValue::Array(array) => Some(array.len()),
+                _ => None,
+            })
+            .unwrap_or(number_rows);
+
+        // Convert all scalars to arrays for uniform processing
+        let arrays: Result<Vec<_>> = args
+            .iter()
+            .map(|arg| match arg {
+                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
+                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
+            })
+            .collect();
+        let arrays = arrays?;
+
+        // Use NullBuffer::union to combine all null buffers
+        let combined_nulls = arrays
+            .iter()
+            .map(|arr| arr.nulls())
+            .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
+
+        match combined_nulls {
+            Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
+            None => Ok(NullMaskResolution::NoMask),
+        }
+    }
+}
+
+/// Apply NULL mask to the result using NullBuffer::union
+fn apply_null_mask(
+    result: ColumnarValue,
+    null_mask: NullMaskResolution,
+) -> Result<ColumnarValue> {
+    match (result, null_mask) {
+        // Scalar with ReturnNull mask means return NULL
+        (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
+            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
+        }
+        // Scalar without mask, return as-is
+        (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
+        // Array with NULL mask - use NullBuffer::union to combine nulls
+        (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
+            // Combine the result's existing nulls with our computed null mask
+            let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));
+
+            // Create new array with combined nulls
+            let new_array = array
+                .into_data()
+                .into_builder()
+                .nulls(combined_nulls)
+                .build()?;
+
+            Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
+                new_array,
+            ))))
+        }
+        // Array without NULL mask, return as-is
+        (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
+        // Edge cases that shouldn't happen in practice
+        (scalar, _) => Ok(scalar),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::function::utils::test::test_scalar_function;
+    use arrow::array::StringArray;
+    use arrow::datatypes::DataType;
+    use datafusion_common::Result;
+
+    #[test]
+    fn test_concat_basic() -> Result<()> {
+        test_scalar_function!(
+            SparkConcat::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
+            ],
+            Ok(Some("SparkSQL")),
+            &str,
+            DataType::Utf8,
+            StringArray
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_concat_with_null() -> Result<()> {
+        test_scalar_function!(
+            SparkConcat::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+            ],
+            Ok(None),
+            &str,
+            DataType::Utf8,
+            StringArray
+        );
+        Ok(())
+    }
+}
diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs
index e83b696bc1ba..a0c76cfabeaf 100644
--- a/datafusion/spark/src/function/string/mod.rs
+++ b/datafusion/spark/src/function/string/mod.rs
@@ -17,6 +17,7 @@
 
 pub mod ascii;
 pub mod char;
+pub mod concat;
 pub mod ilike;
 pub mod like;
 pub mod luhn_check;
@@ -27,6 +28,7 @@ use std::sync::Arc;
 
 make_udf_function!(ascii::SparkAscii, ascii);
 make_udf_function!(char::CharFunc, char);
+make_udf_function!(concat::SparkConcat, concat);
 make_udf_function!(ilike::SparkILike, ilike);
 make_udf_function!(like::SparkLike, like);
 make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check);
@@ -44,6 +46,11 @@ pub mod expr_fn {
         "Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).",
         arg1
     ));
+    export_functions!((
+        concat,
+        "Concatenates multiple input strings into a single string. Returns NULL if any input is NULL.",
+        args
+    ));
     export_functions!((
         ilike,
         "Returns true if str matches pattern (case insensitive).",
@@ -62,5 +69,5 @@ pub mod expr_fn {
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![ascii(), char(), ilike(), like(), luhn_check()]
+    vec![ascii(), char(), concat(), ilike(), like(), luhn_check()]
 }
diff --git a/datafusion/sqllogictest/test_files/spark/string/concat.slt b/datafusion/sqllogictest/test_files/spark/string/concat.slt
new file mode 100644
index 000000000000..0b796a54a69e
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/spark/string/concat.slt
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+query T
+SELECT concat('Spark', 'SQL');
+----
+SparkSQL
+
+query T
+SELECT concat('Spark', 'SQL', NULL);
+----
+NULL
+
+query T
+SELECT concat('', '1', '', '2');
+----
+12
+
+query T
+SELECT concat();
+----
+(empty)
+
+query T
+SELECT concat('');
+----
+(empty)
+
+
+query T
+SELECT concat(a, b, c) from (select 'a' a, 'b' b, 'c' c union all select null a, 'b', 'c') order by 1 nulls last;
+----
+abc
+NULL