37 changes: 3 additions & 34 deletions native/spark-expr/src/array_funcs/array_repeat.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::utils::make_scalar_function;
 use arrow::array::{
     new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData,
     NullBufferBuilder, OffsetSizeTrait, UInt64Array,
@@ -25,48 +26,16 @@ use arrow::compute::cast;
 use arrow::datatypes::DataType::{LargeList, List};
 use arrow::datatypes::{DataType, Field};
 use datafusion::common::cast::{as_large_list_array, as_list_array, as_uint64_array};
-use datafusion::common::{exec_err, DataFusionError, ScalarValue};
+use datafusion::common::{exec_err, DataFusionError};
 use datafusion::logical_expr::ColumnarValue;
 use std::sync::Arc;
 
-pub fn make_scalar_function<F>(
-    inner: F,
-) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>
-where
-    F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>,
-{
-    move |args: &[ColumnarValue]| {
-        // first, identify if any of the arguments is an Array. If yes, store its `len`,
-        // as any scalar will need to be converted to an array of len `len`.
-        let len = args
-            .iter()
-            .fold(Option::<usize>::None, |acc, arg| match arg {
-                ColumnarValue::Scalar(_) => acc,
-                ColumnarValue::Array(a) => Some(a.len()),
-            });
-
-        let is_scalar = len.is_none();
-
-        let args = ColumnarValue::values_to_arrays(args)?;
-
-        let result = (inner)(&args);
-
-        if is_scalar {
-            // If all inputs are scalar, keeps output as scalar
-            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
-            result.map(ColumnarValue::Scalar)
-        } else {
-            result.map(ColumnarValue::Array)
-        }
-    }
-}
-
 pub fn spark_array_repeat(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
     make_scalar_function(spark_array_repeat_inner)(args)
 }
 
 /// Array_repeat SQL function
-fn spark_array_repeat_inner(args: &[ArrayRef]) -> datafusion::common::Result<ArrayRef> {
+fn spark_array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
     let element = &args[0];
     let count_array = &args[1];
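Context for the deletion above: the PR removes the file-local `make_scalar_function` wrapper and imports it from `crate::utils` instead (see the added `use crate::utils::make_scalar_function;`). The wrapper's contract, as shown in the deleted block, is to broadcast scalar arguments to arrays, run an array kernel, and return a scalar only when every input was a scalar. Below is a minimal self-contained sketch of that contract, assuming the relocated helper keeps the same signature; the `double_kernel` and `main` here are illustrative and not part of the PR.

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int64Array};
use datafusion::common::{DataFusionError, ScalarValue};
use datafusion::logical_expr::ColumnarValue;

/// Same wrapper logic as the helper this PR moves into `crate::utils`.
fn make_scalar_function<F>(
    inner: F,
) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>
where
    F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>,
{
    move |args: &[ColumnarValue]| {
        // Any array argument fixes the row count; all-scalar inputs mean a scalar output.
        let is_scalar = !args.iter().any(|a| matches!(a, ColumnarValue::Array(_)));
        // Broadcast scalars to (length-1) arrays so the kernel only ever sees arrays.
        let arrays = ColumnarValue::values_to_arrays(args)?;
        let result = inner(&arrays)?;
        if is_scalar {
            // Collapse the length-1 result array back into a scalar.
            Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(&result, 0)?))
        } else {
            Ok(ColumnarValue::Array(result))
        }
    }
}

/// Toy kernel standing in for `spark_array_repeat_inner`: double every i64 value.
fn double_kernel(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
    let ints = args[0]
        .as_any()
        .downcast_ref::<Int64Array>()
        .expect("i64 input");
    let doubled: Int64Array = ints.iter().map(|v| v.map(|x| x * 2)).collect();
    Ok(Arc::new(doubled) as ArrayRef)
}

fn main() -> Result<(), DataFusionError> {
    let f = make_scalar_function(double_kernel);
    // Scalar in -> scalar out, even though the kernel only handles arrays.
    let out = f(&[ColumnarValue::Scalar(ScalarValue::Int64(Some(21)))])?;
    assert!(matches!(
        out,
        ColumnarValue::Scalar(ScalarValue::Int64(Some(42)))
    ));
    Ok(())
}
```

This is why `spark_array_repeat` can keep returning a scalar for all-literal inputs even though `spark_array_repeat_inner` only ever sees arrays.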
@@ -15,210 +15,177 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::utils::make_scalar_function;
 use arrow::array::builder::GenericStringBuilder;
-use arrow::array::cast::as_dictionary_array;
 use arrow::array::types::Int32Type;
-use arrow::array::{make_array, Array, AsArray, DictionaryArray};
+use arrow::array::{as_dictionary_array, make_array, Array, AsArray, DictionaryArray};
 use arrow::array::{ArrayRef, OffsetSizeTrait};
 use arrow::datatypes::DataType;
-use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
+use datafusion::common::{cast::as_generic_string_array, DataFusionError};
 use datafusion::physical_plan::ColumnarValue;
 use std::sync::Arc;
 
 const SPACE: &str = " ";
 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
 pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    spark_read_side_padding2(args, false)
+    make_scalar_function(spark_read_side_padding_no_truncate)(args)
 }
 
 /// Custom `rpad` because DataFusion's `rpad` has differences in unicode handling
 pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    make_scalar_function(spark_read_side_padding_truncate)(args)
+}
+
+pub fn spark_read_side_padding_truncate(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
     spark_read_side_padding2(args, true)
 }
 
+pub fn spark_read_side_padding_no_truncate(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
+    spark_read_side_padding2(args, false)
+}
+
 fn spark_read_side_padding2(
-    args: &[ColumnarValue],
+    args: &[ArrayRef],
     truncate: bool,
-) -> Result<ColumnarValue, DataFusionError> {
+) -> Result<ArrayRef, DataFusionError> {
     match args {
-        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
-            match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(
-                    array,
-                    truncate,
-                    ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                    SPACE,
-                ),
-                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
-                    array,
-                    truncate,
-                    ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                    SPACE,
-                ),
-                // Dictionary support required for SPARK-48498
-                DataType::Dictionary(_, value_type) => {
-                    let dict = as_dictionary_array::<Int32Type>(array);
-                    let col = if value_type.as_ref() == &DataType::Utf8 {
-                        spark_read_side_padding_internal::<i32>(
-                            dict.values(),
-                            truncate,
-                            ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                            SPACE,
-                        )?
-                    } else {
-                        spark_read_side_padding_internal::<i64>(
-                            dict.values(),
-                            truncate,
-                            ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                            SPACE,
-                        )?
-                    };
-                    // col consists of an array, so arg of to_array() is not used. Can be anything
-                    let values = col.to_array(0)?;
-                    let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
-                    Ok(ColumnarValue::Array(make_array(result.into())))
-                }
-                other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad/read_side_padding",
-                ))),
-            }
-        }
-        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))] =>
-        {
-            match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(
-                    array,
-                    truncate,
-                    ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                    string,
-                ),
-                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
-                    array,
-                    truncate,
-                    ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
-                    string,
-                ),
-                other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad/read_side_padding",
-                ))),
-            }
-        }
-        [ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => match array.data_type() {
-            DataType::Utf8 => spark_read_side_padding_internal::<i32>(
-                array,
-                truncate,
-                ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
-                SPACE,
-            ),
-            DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
-                array,
-                truncate,
-                ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
-                SPACE,
-            ),
+        [array, array_int] => match array.data_type() {
+            DataType::Utf8 => {
+                spark_read_side_padding_space_internal::<i32>(array, truncate, array_int)
+            }
+            DataType::LargeUtf8 => {
+                spark_read_side_padding_space_internal::<i64>(array, truncate, array_int)
+            }
             other => Err(DataFusionError::Internal(format!(
                 "Unsupported data type {other:?} for function rpad/read_side_padding",
             ))),
         },
-        [ColumnarValue::Array(array), ColumnarValue::Array(array_int), ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))] => {
-            match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(
-                    array,
-                    truncate,
-                    ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
-                    string,
-                ),
-                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
-                    array,
-                    truncate,
-                    ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
-                    string,
-                ),
+        [array, array_int, array_pad_string] => {
+            match (array.data_type(), array_pad_string.data_type()) {
+                (DataType::Utf8, DataType::Utf8) => {
+                    spark_read_side_padding_internal::<i32, i32, i32>(
+                        array,
+                        truncate,
+                        array_int,
+                        array_pad_string,
+                    )
+                }
+                (DataType::Utf8, DataType::LargeUtf8) => {
+                    spark_read_side_padding_internal::<i32, i64, i64>(
+                        array,
+                        truncate,
+                        array_int,
+                        array_pad_string,
+                    )
+                }
+                (DataType::LargeUtf8, DataType::Utf8) => {
+                    spark_read_side_padding_internal::<i64, i32, i64>(
+                        array,
+                        truncate,
+                        array_int,
+                        array_pad_string,
+                    )
+                }
+                (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                    spark_read_side_padding_internal::<i64, i64, i64>(
+                        array,
+                        truncate,
+                        array_int,
+                        array_pad_string,
+                    )
+                }
+                // Dictionary support required for SPARK-48498
+                (DataType::Dictionary(_, value_type), DataType::Utf8) => {
+                    let dict = as_dictionary_array::<Int32Type>(array);
+                    let values = if value_type.as_ref() == &DataType::Utf8 {
+                        spark_read_side_padding_internal::<i32, i32, i32>(
+                            dict.values(),
+                            truncate,
+                            array_int,
+                            array_pad_string,
+                        )?
+                    } else {
+                        spark_read_side_padding_internal::<i64, i32, i64>(
+                            dict.values(),
+                            truncate,
+                            array_int,
+                            array_pad_string,
+                        )?
+                    };
+                    let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
+                    Ok(make_array(result.into()))
+                }
                 other => Err(DataFusionError::Internal(format!(
                     "Unsupported data type {other:?} for function rpad/read_side_padding",
                 ))),
             }
         }
         other => Err(DataFusionError::Internal(format!(
             "Unsupported arguments {other:?} for function rpad/read_side_padding",
         ))),
     }
 }
 
-fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
+fn spark_read_side_padding_space_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
     truncate: bool,
-    pad_type: ColumnarValue,
-    pad_string: &str,
-) -> Result<ColumnarValue, DataFusionError> {
+    array_int: &ArrayRef,
+) -> Result<ArrayRef, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
-    match pad_type {
-        ColumnarValue::Array(array_int) => {
-            let int_pad_array = array_int.as_primitive::<Int32Type>();
-
-            let mut builder = GenericStringBuilder::<T>::with_capacity(
-                string_array.len(),
-                string_array.len() * int_pad_array.len(),
-            );
-
-            for (string, length) in string_array.iter().zip(int_pad_array) {
-                match string {
-                    Some(string) => builder.append_value(add_padding_string(
-                        string.parse().unwrap(),
-                        length.unwrap() as usize,
-                        truncate,
-                        pad_string,
-                    )?),
-                    _ => builder.append_null(),
-                }
-            }
-            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
-        }
-        ColumnarValue::Scalar(const_pad_length) => {
-            let length = 0.max(i32::try_from(const_pad_length)?) as usize;
-
-            let mut builder = GenericStringBuilder::<T>::with_capacity(
-                string_array.len(),
-                string_array.len() * length,
-            );
-
-            for string in string_array.iter() {
-                match string {
-                    Some(string) => builder.append_value(add_padding_string(
-                        string.parse().unwrap(),
-                        length,
-                        truncate,
-                        pad_string,
-                    )?),
-                    _ => builder.append_null(),
-                }
-            }
-            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
-        }
-    }
+    let int_pad_array = array_int.as_primitive::<Int32Type>();
+
+    let mut builder = GenericStringBuilder::<T>::with_capacity(
+        string_array.len(),
+        string_array.len() * int_pad_array.len(),
+    );
+
+    for (string, length) in string_array.iter().zip(int_pad_array) {
+        match (string, length) {
+            (Some(string), Some(length)) => builder.append_value(add_padding_string(
+                string.parse().unwrap(),
+                length as usize,
+                truncate,
+                SPACE,
+            )?),
+            _ => builder.append_null(),
+        }
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
+fn spark_read_side_padding_internal<T: OffsetSizeTrait, O: OffsetSizeTrait, S: OffsetSizeTrait>(
+    array: &ArrayRef,
+    truncate: bool,
+    array_int: &ArrayRef,
+    pad_string_array: &ArrayRef,
+) -> Result<ArrayRef, DataFusionError> {
+    let string_array = as_generic_string_array::<T>(array)?;
+    let int_pad_array = array_int.as_primitive::<Int32Type>();
+    let pad_string_array = as_generic_string_array::<O>(pad_string_array)?;
+
+    let mut builder = GenericStringBuilder::<S>::with_capacity(
+        string_array.len(),
+        string_array.len() * int_pad_array.len(),
+    );
+
+    for ((string, length), pad_string) in string_array
+        .iter()
+        .zip(int_pad_array)
+        .zip(pad_string_array.iter())
+    {
+        match (string, length, pad_string) {
+            (Some(string), Some(length), Some(pad_string)) => {
+                builder.append_value(add_padding_string(
+                    string.parse().unwrap(),
+                    length as usize,
+                    truncate,
+                    pad_string,
+                )?)
+            }
+            _ => builder.append_null(),
+        }
+    }
+    Ok(Arc::new(builder.finish()))
+}
 
 fn add_padding_string(
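Two details in the new code are worth spelling out. First, the three type parameters of `spark_read_side_padding_internal<T, O, S>` are the Arrow offset widths of the input strings (`T`), the pad strings (`O`), and the output (`S`); the match arms promote the output to `LargeUtf8` (`i64`) whenever either input side is large. Second, the dictionary arm keeps the SPARK-48498 path cheap by padding only the dictionary values and re-attaching the original keys, rather than expanding the dictionary row by row. Below is a self-contained sketch of that values-then-keys technique; the space-padding closure is a toy stand-in for the real padding call, since the body of `add_padding_string` sits outside this diff.

```rust
use std::sync::Arc;

use arrow::array::{make_array, Array, ArrayRef, DictionaryArray, StringArray};
use arrow::datatypes::Int32Type;
use arrow::error::ArrowError;

/// Illustrative sketch: apply a string kernel to the dictionary *values* once,
/// then rebuild the dictionary with the original keys.
fn pad_dictionary_values(
    dict: &DictionaryArray<Int32Type>,
    target_len: usize,
) -> Result<ArrayRef, ArrowError> {
    let values = dict
        .values()
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("string-valued dictionary");
    // Toy stand-in for the padding kernel: right-pad with spaces to `target_len` chars.
    let padded: StringArray = values
        .iter()
        .map(|v| v.map(|s| format!("{s:<target_len$}")))
        .collect();
    // Re-attach the original keys; every row now points at a padded value.
    let result = DictionaryArray::try_new(dict.keys().clone(), Arc::new(padded))?;
    Ok(make_array(result.into()))
}
```

For a column with many rows but few distinct values, this pads each distinct value once instead of once per row, which is the point of keeping the dictionary encoding intact.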