
feat: add expression array_size #1122

Open · wants to merge 10 commits into main
2 changes: 1 addition & 1 deletion docs/spark_expressions_support.md
@@ -117,7 +117,7 @@
- [x] ~

### collection_funcs
- - [ ] array_size
+ - [x] array_size
- [ ] cardinality
- [ ] concat
- [x] reverse
12 changes: 10 additions & 2 deletions native/core/src/execution/datafusion/planner.rs
@@ -98,8 +98,9 @@ use datafusion_comet_proto::{
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
};
use datafusion_comet_spark_expr::{
-    ArrayInsert, Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField,
-    HourExpr, IfExpr, ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
+    ArrayInsert, ArraySize, Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields,
+    GetStructField, HourExpr, IfExpr, ListExtract, MinuteExpr, RLike, SecondExpr,
+    TimestampTruncExpr, ToJson,
};
use datafusion_common::scalar::ScalarStructBuilder;
use datafusion_common::{
@@ -736,6 +737,13 @@ impl PhysicalPlanner {
expr.legacy_negative_index,
)))
}
ExprStruct::ArraySize(expr) => {
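                // Deserialize the child array expression from the protobuf node
                // and wrap it in the native ArraySize physical expression.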
let child = self.create_expr(
expr.src_array_expr.as_ref().unwrap(),
Arc::clone(&input_schema),
)?;
Ok(Arc::new(ArraySize::new(child)))
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
5 changes: 5 additions & 0 deletions native/proto/src/proto/expr.proto
@@ -84,6 +84,7 @@ message Expr {
GetArrayStructFields get_array_struct_fields = 57;
BinaryExpr array_append = 58;
ArrayInsert array_insert = 59;
ArraySize array_size = 60;
}
}

@@ -412,6 +413,10 @@ message ArrayInsert {
bool legacy_negative_index = 4;
}

message ArraySize {
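  // The array expression whose size is computed.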
Expr src_array_expr = 1;
}

message DataType {
enum DataTypeId {
BOOL = 0;
2 changes: 1 addition & 1 deletion native/spark-expr/src/lib.rs
@@ -37,7 +37,7 @@ pub mod utils;
pub use cast::{spark_cast, Cast};
pub use error::{SparkError, SparkResult};
pub use if_expr::IfExpr;
-pub use list::{ArrayInsert, GetArrayStructFields, ListExtract};
+pub use list::{ArrayInsert, ArraySize, GetArrayStructFields, ListExtract};
pub use regexp::RLike;
pub use structs::{CreateNamedStruct, GetStructField};
pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
86 changes: 86 additions & 0 deletions native/spark-expr/src/list.rs
@@ -708,6 +708,92 @@ impl PartialEq<dyn Any> for ArrayInsert {
}
}
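
/// Spark-compatible `size`/`array_size` over arrays: returns the number of
/// elements in the array, or null when the input array is null (non-legacy
/// semantics).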

#[derive(Debug, Hash)]
pub struct ArraySize {
src_array_expr: Arc<dyn PhysicalExpr>,
}

impl ArraySize {
pub fn new(src_array_expr: Arc<dyn PhysicalExpr>) -> Self {
Self { src_array_expr }
}
}

impl Display for ArraySize {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "ArraySize [array: {:?}]", self.src_array_expr)
}
}

impl PartialEq<dyn Any> for ArraySize {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| self.src_array_expr.eq(&x.src_array_expr))
.unwrap_or(false)
}
}

impl PhysicalExpr for ArraySize {
fn as_any(&self) -> &dyn Any {
self
}

fn data_type(&self, _input_schema: &Schema) -> DataFusionResult<DataType> {
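        // Spark's size() and array_size() both return a 32-bit integer.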
Ok(DataType::Int32)
}

    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
        // The output is null exactly when the input array is null, so
        // nullability follows the child expression.
        self.src_array_expr.nullable(input_schema)
    }

fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
let array_value = self
.src_array_expr
.evaluate(batch)?
.into_array(batch.num_rows())?;
match array_value.data_type() {
DataType::List(_) => {
let list_array = as_list_array(&array_value)?;
let mut builder = Int32Array::builder(list_array.len());
for i in 0..list_array.len() {
if list_array.is_null(i) {
Review comment (Contributor): You might want to handle legacySizeOfNull here — see the sketch after this impl.

builder.append_null();
} else {
builder.append_value(list_array.value_length(i));
Review comment (Contributor): Will value_length(i) include nulls in the array at index i?
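                // Yes — value_length(i) is end_offset - start_offset for slot i,
                // so it counts null elements too, matching Spark's size().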

}
}
                let sizes_array = builder.finish();
Ok(ColumnarValue::Array(Arc::new(sizes_array)))
Review comment (Author): Is there a more efficient way to do this?
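                // One possibility (a sketch only, not benchmarked): build the
                // lengths straight from the list offsets instead of a per-row
                // builder loop:
                //
                //     let lengths: Int32Array = list_array
                //         .value_offsets()
                //         .windows(2)
                //         .zip(0..list_array.len())
                //         .map(|(w, i)| (!list_array.is_null(i)).then(|| w[1] - w[0]))
                //         .collect();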

}
_ => Err(DataFusionError::Internal(format!(
"Unexpected data type in ArraySize: {:?}",
array_value.data_type()
))),
}
}

fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.src_array_expr]
}

    fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
match children.len() {
1 => Ok(Arc::new(ArraySize::new(Arc::clone(&children[0])))),
            _ => internal_err!("ArraySize should have exactly one child"),
}
}

    fn dyn_hash(&self, state: &mut dyn Hasher) {
        let mut s = state;
        // The derived Hash impl already covers src_array_expr.
        self.hash(&mut s);
    }
}
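
A possible answer to the legacySizeOfNull review comment above — a minimal
sketch only, assuming a `legacy_size_of_null` flag were added to `ArraySize`
and threaded through the constructor, serde, and Scala side (none of that
plumbing exists in this PR):

    // Hypothetical variant of the null branch in evaluate():
    if list_array.is_null(i) {
        if self.legacy_size_of_null {
            // Legacy Spark behavior: size(NULL) returns -1.
            builder.append_value(-1);
        } else {
            // Non-legacy behavior (and array_size): size(NULL) returns NULL.
            builder.append_null();
        }
    } else {
        builder.append_value(list_array.value_length(i));
    }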

#[cfg(test)]
mod test {
use crate::list::{array_insert, list_extract, zero_based_index};
10 changes: 10 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2220,6 +2220,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
expr.children(2))
None
}
      case expr if expr.prettyName == "size" =>
        val srcExprProto = exprToProto(expr.children.head, inputs, binding)
        if (srcExprProto.isDefined) {
          val arraySizeBuilder = ExprOuterClass.ArraySize
            .newBuilder()
            .setSrcArrayExpr(srcExprProto.get)
          Some(
            ExprOuterClass.Expr
              .newBuilder()
              .setArraySize(arraySizeBuilder)
              .build())
        } else {
          withInfo(expr, expr.children: _*)
          None
        }
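
A sketch of a more targeted alternative (not part of this PR; assumes `Size`
from org.apache.spark.sql.catalyst.expressions is imported): on Spark 3.4+,
array_size is runtime-replaced by Size(child, legacySizeOfNull = false), so
matching the class directly also lets the guard reject the legacy form, which
the native expression does not implement:

      case Size(child, legacySizeOfNull) if !legacySizeOfNull =>
        exprToProto(child, inputs, binding).map { childProto =>
          ExprOuterClass.Expr
            .newBuilder()
            .setArraySize(
              ExprOuterClass.ArraySize.newBuilder().setSrcArrayExpr(childProto))
            .build()
        }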

case ElementAt(child, ordinal, defaultValue, failOnError)
if child.dataType.isInstanceOf[ArrayType] =>
17 changes: 17 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2390,4 +2390,21 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
checkSparkAnswer(df.select("arrUnsupportedArgs"))
}
}

test("ArraySize") {
assume(isSpark34Plus)
Seq(true, false).foreach(dictionaryEnabled =>
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
val df = spark.read
.parquet(path.toString)
.withColumn("arr", array(col("_4"), lit(null), col("_4")))
.withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
.withColumn("arrSizeResult", expr("array_size(arrInsertResult)"))
.withColumn("arrSizeResultNull", expr("array_size(null)"))
checkSparkAnswerAndOperator(df.select("arrSizeResult"))
checkSparkAnswerAndOperator(df.select("arrSizeResultNull"))
})
}
}