feat: add expression array_size #1122

Open · wants to merge 10 commits into base: main
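For reference, Spark's array_size (available since Spark 3.3) returns the number of elements in an array, and NULL when the input array is NULL. A quick illustration of the semantics this PR offloads to the native engine, assuming an active SparkSession named spark:

spark.sql("SELECT array_size(array(1, 2, 3))").show()           // 3
spark.sql("SELECT array_size(CAST(NULL AS ARRAY<INT>))").show() // NULL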
2 changes: 1 addition & 1 deletion docs/spark_expressions_support.md
@@ -117,7 +117,7 @@
- [x] ~

### collection_funcs
- [ ] array_size
- [x] array_size
- [ ] cardinality
- [ ] concat
- [x] reverse
9 changes: 8 additions & 1 deletion native/core/src/execution/planner.rs
@@ -92,7 +92,7 @@ use datafusion_comet_proto::{
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
};
use datafusion_comet_spark_expr::{
ArrayInsert, Avg, AvgDecimal, BitwiseNotExpr, Cast, CheckOverflow, Contains, Correlation,
ArrayInsert, ArraySize, Avg, AvgDecimal, BitwiseNotExpr, Cast, CheckOverflow, Contains, Correlation,
Covariance, CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, GetStructField,
HourExpr, IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr,
SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal,
@@ -745,6 +745,13 @@ impl PhysicalPlanner {
));
Ok(array_has_expr)
}
ExprStruct::ArraySize(expr) => {
let child = self.create_expr(
expr.src_array_expr.as_ref().unwrap(),
Arc::clone(&input_schema),
)?;
Ok(Arc::new(ArraySize::new(child)))
}
ExprStruct::ArrayRemove(expr) => {
let src_array_expr =
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
5 changes: 5 additions & 0 deletions native/proto/src/proto/expr.proto
@@ -89,6 +89,7 @@ message Expr {
BinaryExpr array_intersect = 62;
ArrayJoin array_join = 63;
BinaryExpr arrays_overlap = 64;
ArraySize array_size = 65;
}
}

@@ -417,6 +418,10 @@ message ArrayInsert {
bool legacy_negative_index = 4;
}

message ArraySize {
Expr src_array_expr = 1;
}

message ArrayJoin {
Expr array_expr = 1;
Expr delimiter_expr = 2;
95 changes: 95 additions & 0 deletions native/spark-expr/src/array_funcs/array_size.rs
@@ -0,0 +1,95 @@
use std::any::Any;
use std::fmt::{Display, Formatter};
use std::hash::Hash;
use std::sync::Arc;
use arrow_array::{Array, Int32Array, RecordBatch};
use arrow_schema::{DataType, Schema};
use datafusion::physical_expr::PhysicalExpr;
use datafusion_common::cast::as_list_array;
use datafusion_common::{internal_err, DataFusionError, Result as DataFusionResult};
use datafusion_expr_common::columnar_value::ColumnarValue;
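/// Physical expression implementing Spark's array_size for list inputs: evaluates the
/// child array expression and returns each row's element count as an Int32, producing
/// null for rows where the input array is null.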
#[derive(Debug, Eq)]
pub struct ArraySize {
src_array_expr: Arc<dyn PhysicalExpr>,
}

impl PartialEq for ArraySize {
fn eq(&self, other: &Self) -> bool {
self.src_array_expr.eq(&other.src_array_expr)
}
}

impl ArraySize {
pub fn new(src_array_expr: Arc<dyn PhysicalExpr>) -> Self {
Self { src_array_expr }
}
}

impl Display for ArraySize {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "ArraySize [array: {:?}]", self.src_array_expr)
}
}

impl Hash for ArraySize {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.src_array_expr.hash(state);
}
}

impl PhysicalExpr for ArraySize {
fn as_any(&self) -> &dyn Any {
self
}

fn data_type(&self, _input_schema: &Schema) -> DataFusionResult<DataType> {
Ok(DataType::Int32)
}

fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
self.src_array_expr.nullable(input_schema)
}

fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
let array_value = self
.src_array_expr
.evaluate(batch)?
.into_array(batch.num_rows())?;
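        // Produce one Int32 size per input row; null array rows yield null sizes.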
match array_value.data_type() {
DataType::List(_) => {
let list_array = as_list_array(&array_value)?;
let mut builder = Int32Array::builder(list_array.len());
for i in 0..list_array.len() {
if list_array.is_null(i) {
builder.append_null();
} else {
builder.append_value(list_array.value_length(i));
}
}
                let sizes_array = builder.finish();
Ok(ColumnarValue::Array(Arc::new(sizes_array)))
}
_ => Err(DataFusionError::Internal(format!(
"Unexpected data type in ArraySize: {:?}",
array_value.data_type()
))),
}
}

fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.src_array_expr]
}

fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
match children.len() {
1 => Ok(Arc::new(ArraySize::new(Arc::clone(&children[0])))),
_ => internal_err!("ArraySize should have exactly one child"),
}
}
}
2 changes: 2 additions & 0 deletions native/spark-expr/src/array_funcs/mod.rs
@@ -18,7 +18,9 @@
mod array_insert;
mod get_array_struct_fields;
mod list_extract;
mod array_size;

pub use array_insert::ArrayInsert;
pub use array_size::ArraySize;
pub use get_array_struct_fields::GetArrayStructFields;
pub use list_extract::ListExtract;
10 changes: 10 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2333,6 +2333,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
expr.children(2))
None
}
case Size(child, false) if child.dataType.isInstanceOf[ArrayType] =>
val srcExprProto = exprToProto(child, inputs, binding)
val arraySizeBuilder = ExprOuterClass.ArraySize
.newBuilder()
.setSrcArrayExpr(srcExprProto.get)
Some(
ExprOuterClass.Expr
.newBuilder()
.setArraySize(arraySizeBuilder)
.build())

case ElementAt(child, ordinal, defaultValue, failOnError)
if child.dataType.isInstanceOf[ArrayType] =>
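A note on the serde mapping added above: array_size is a RuntimeReplaceable expression that Spark rewrites to Size(child, legacySizeOfNull = false), which is why the new case matches Size rather than a dedicated ArraySize node. Also, srcExprProto.get will throw if the child expression itself cannot be serialized; a more defensive sketch, assuming the withInfo fallback helper used elsewhere in QueryPlanSerde and the enclosing expr being matched, would fall back to Spark instead:

    case Size(child, false) if child.dataType.isInstanceOf[ArrayType] =>
      exprToProto(child, inputs, binding) match {
        case Some(srcExprProto) =>
          val arraySizeBuilder = ExprOuterClass.ArraySize
            .newBuilder()
            .setSrcArrayExpr(srcExprProto)
          Some(ExprOuterClass.Expr.newBuilder().setArraySize(arraySizeBuilder).build())
        case None =>
          // Record why Comet fell back to Spark for this expression.
          withInfo(expr, child)
          None
      }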
112 changes: 111 additions & 1 deletion spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.sql.types.{Decimal, DecimalType}

import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus, isSpark40Plus}
import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus, isSpark35Plus, isSpark40Plus}

class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._
@@ -2601,4 +2601,114 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("array_append") {
assume(isSpark34Plus)
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
spark.read.parquet(path.toString).createOrReplaceTempView("t1");
checkSparkAnswerAndOperator(spark.sql("Select array_append(array(_1),false) from t1"))
checkSparkAnswerAndOperator(
spark.sql("SELECT array_append(array(_2, _3, _4), 4) FROM t1"))
checkSparkAnswerAndOperator(
spark.sql("SELECT array_append(array(_2, _3, _4), null) FROM t1"));
checkSparkAnswerAndOperator(
spark.sql("SELECT array_append(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1"));
checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_8), 'test') FROM t1"));
checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_19), _19) FROM t1"));
checkSparkAnswerAndOperator(
spark.sql("SELECT array_append((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
}
}
}

test("array_prepend") {
assume(isSpark35Plus) // in Spark 3.5 array_prepend is implemented via array_insert
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
spark.read.parquet(path.toString).createOrReplaceTempView("t1");
checkSparkAnswerAndOperator(spark.sql("Select array_prepend(array(_1),false) from t1"))
checkSparkAnswerAndOperator(
spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1"))
checkSparkAnswerAndOperator(
spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM t1"));
checkSparkAnswerAndOperator(
spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1"));
checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_8), 'test') FROM t1"));
checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_19), _19) FROM t1"));
checkSparkAnswerAndOperator(
spark.sql("SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
}
}
}

test("ArrayInsert") {
assume(isSpark34Plus)
Seq(true, false).foreach(dictionaryEnabled =>
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
val df = spark.read
.parquet(path.toString)
.withColumn("arr", array(col("_4"), lit(null), col("_4")))
.withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
.withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr, -1, 1)"))
.withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8, 1)"))
.withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr, -8, 1)"))
.withColumn("arrInsertNone", expr("array_insert(arr, 1, null)"))
checkSparkAnswerAndOperator(df.select("arrInsertResult"))
checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult"))
checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize"))
checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize"))
checkSparkAnswerAndOperator(df.select("arrInsertNone"))
})
}

test("ArrayInsertUnsupportedArgs") {
    // This test checks that the else branch in the ArrayInsert mapping to Comet
    // is valid and that the fallback to Spark works as expected.
assume(isSpark34Plus)
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = false, 10000)
val df = spark.read
.parquet(path.toString)
.withColumn("arr", array(col("_4"), lit(null), col("_4")))
.withColumn("idx", udf((_: Int) => 1).apply(col("_4")))
.withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)"))
checkSparkAnswer(df.select("arrUnsupportedArgs"))
}
}

test("ArraySize") {
assume(isSpark34Plus)
Seq(true, false).foreach(dictionaryEnabled =>
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
val df = spark.read
.parquet(path.toString)
.withColumn("arr", array(col("_4"), lit(null), col("_4")))
.withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
.withColumn("arrSizeResult", expr("array_size(arrInsertResult)"))
.withColumn("arrSizeResultNull", expr("array_size(null)"))
checkSparkAnswerAndOperator(df.select("arrSizeResult"))
checkSparkAnswerAndOperator(df.select("arrSizeResultNull"))
})
}

test("array_contains") {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = false, n = 10000)
spark.read.parquet(path.toString).createOrReplaceTempView("t1");
checkSparkAnswerAndOperator(
spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
checkSparkAnswerAndOperator(
spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
}
}
}