Feat: support array_except function #1343

Open
wants to merge 14 commits into main
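For context, Spark's array_except(a, b) returns the elements of a that are not in b, with duplicates removed; this PR lets Comet execute it natively through DataFusion's array_except_udf instead of falling back to Spark. A minimal illustration of the semantics being offloaded (any spark-shell session):

```scala
// Expected Spark semantics: left minus right, de-duplicated, left order kept.
spark.sql("SELECT array_except(array(1, 2, 2, 3), array(2, 4)) AS r").show()
// +------+
// |     r|
// +------+
// |[1, 3]|
// +------+
```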
17 changes: 17 additions & 0 deletions native/core/src/execution/planner.rs
@@ -68,6 +68,7 @@ use datafusion_functions_nested::array_has::array_has_any_udf;
use datafusion_functions_nested::concat::ArrayAppend;
use datafusion_functions_nested::except::array_except_udf;
use datafusion_functions_nested::remove::array_remove_all_udf;
use datafusion_functions_nested::set_ops::array_intersect_udf;
use datafusion_functions_nested::string::array_to_string_udf;
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};

@@ -829,6 +830,22 @@ impl PhysicalPlanner {
));
Ok(array_has_any_expr)
}
ExprStruct::ArrayExcept(expr) => {
    let left_expr =
        self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
    let right_expr =
        self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
    // array_except returns the same list type as its left input.
    let return_type = left_expr.data_type(&input_schema)?;
    let args = vec![left_expr, right_expr];

    let array_except_expr = Arc::new(ScalarFunctionExpr::new(
        "array_except",
        array_except_udf(),
        args,
        return_type,
    ));
    Ok(array_except_expr)
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
1 change: 1 addition & 0 deletions native/proto/src/proto/expr.proto
@@ -89,6 +89,7 @@ message Expr {
BinaryExpr array_intersect = 62;
ArrayJoin array_join = 63;
BinaryExpr arrays_overlap = 64;
BinaryExpr array_except = 65;
}
}

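No new message type is needed: like array_intersect and arrays_overlap above it, array_except reuses the two-child BinaryExpr message, here under field number 65. A hedged sketch of the message the Scala serde assembles (the package of the generated ExprOuterClass and its left/right setter names are assumptions about the generated protobuf API):

```scala
import org.apache.comet.serde.ExprOuterClass // assumed java_package of expr.proto

// Stand-in children; a real caller passes fully serialized child expressions.
val leftExpr  = ExprOuterClass.Expr.newBuilder().build()
val rightExpr = ExprOuterClass.Expr.newBuilder().build()

val binary = ExprOuterClass.BinaryExpr
  .newBuilder()
  .setLeft(leftExpr)   // assumed builder method for the left child
  .setRight(rightExpr) // assumed builder method for the right child
  .build()
val arrayExcept = ExprOuterClass.Expr
  .newBuilder()
  .setArrayExcept(binary) // field 65, added by this PR
  .build()
```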
@@ -2387,6 +2387,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
case _: ArrayIntersect => convert(CometArrayIntersect)
case _: ArrayJoin => convert(CometArrayJoin)
case _: ArraysOverlap => convert(CometArraysOverlap)
case _ if expr.prettyName == "array_except" =>
convert(CometArrayExcept)
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
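Note that this case matches on prettyName rather than on the ArrayExcept class, unlike the class-pattern cases above it; a plausible reason (an assumption, not stated in the PR) is that it keeps QueryPlanSerde decoupled from version-specific expression classes reached through CometExprShim. The name the guard keys on is easy to verify against Spark's own expression:

```scala
import org.apache.spark.sql.catalyst.expressions.{ArrayExcept, Literal}

// Build Spark's ArrayExcept directly and confirm the name the serde keys on.
val e = ArrayExcept(Literal.create(Seq(1, 2, 2, 3)), Literal.create(Seq(2, 4)))
assert(e.prettyName == "array_except")
```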
15 changes: 15 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -126,6 +126,21 @@ object CometArraysOverlap extends CometExpressionSerde with IncompatExpr {
}
}

object CometArrayExcept extends CometExpressionSerde with CometExprShim {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
createBinaryExpr(
expr,
expr.children(0),
expr.children(1),
inputs,
binding,
(builder, binaryExpr) => builder.setArrayExcept(binaryExpr))
}
}

object CometArrayJoin extends CometExpressionSerde with IncompatExpr {
override def convert(
expr: Expression,
@@ -292,4 +292,89 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("array_except - basic test (only integer values)") {
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
spark.read.parquet(path.toString).createOrReplaceTempView("t1")

checkSparkAnswerAndOperator(
sql("SELECT array_except(array(_2, _3, _4), array(_3, _4)) from t1"))
checkSparkAnswerAndOperator(sql("SELECT array_except(array(_18), array(_19)) from t1"))
checkSparkAnswerAndOperator(spark.sql(
"SELECT array_except((CASE WHEN _2 = _3 THEN array(_2, _3, _4) END), array(_4)) FROM t1"))
Review comments on the query above:

Contributor: _2 = _3 would always be true?

Author: You are right, it is necessary to add an IS NOT NULL condition, thank you.

Contributor: I saw you added WHERE _2 IS NOT NULL, but I am not sure how that helps... _2 = _3 will always be true. Is that intentional?

Author: This is a test case for deduplication checking.
}
}
}
}
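On the review question above: the generated test data makes _2 and _3 equal, so array(_2, _3, _4) deliberately contains a duplicate and the query exercises de-duplication as well as element removal. A minimal reference model of the expected semantics (an assumption based on Spark's documented behavior, not Comet code):

```scala
// Keep left-hand elements that do not occur on the right, de-duplicated,
// preserving their left-to-right order (ignoring Spark's null/NaN edge cases).
def arrayExceptModel[T](left: Seq[T], right: Seq[T]): Seq[T] =
  left.distinct.filterNot(right.toSet)

assert(arrayExceptModel(Seq(1, 2, 2, 3), Seq(2, 4)) == Seq(1, 3))
```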

test("array_except - test all types (native Parquet reader)") {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
val filename = path.toString
val random = new Random(42)
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
ParquetGenerator.makeParquetFile(
random,
spark,
filename,
100,
DataGenOptions(
allowNull = true,
generateNegativeZero = true,
generateArray = false,
generateStruct = false,
generateMap = false))
}
val table = spark.read.parquet(filename)
table.createOrReplaceTempView("t1")
// test with array of each column
for (fieldName <- table.schema.fieldNames) {
sql(s"SELECT array($fieldName, $fieldName) as a, array($fieldName) as b FROM t1")
.createOrReplaceTempView("t2")
val df = sql("SELECT array_except(a, b) FROM t2")
checkSparkAnswerAndOperator(df)
}
}
}
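Each loop iteration pits a two-element array containing a guaranteed duplicate against a one-element array of the same value, so every element of a also occurs in b. For a hypothetical column c1, the generated pair of queries is:

```scala
// Hypothetical expansion for one column; the expected answer is an empty
// array on every row, covering both de-duplication and element removal.
sql("SELECT array(c1, c1) AS a, array(c1) AS b FROM t1").createOrReplaceTempView("t2")
checkSparkAnswerAndOperator(sql("SELECT array_except(a, b) FROM t2"))
```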

test("array_except - test all types (convert from Parquet)") {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
val filename = path.toString
val random = new Random(42)
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
val options = DataGenOptions(
allowNull = true,
generateNegativeZero = true,
generateArray = true,
generateStruct = true,
generateMap = false)
ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
}
withSQLConf(
CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
val table = spark.read.parquet(filename)
table.createOrReplaceTempView("t1")
// test with array of each column
for (field <- table.schema.fields) {
val fieldName = field.name
sql(s"SELECT array($fieldName, $fieldName) as a, array($fieldName) as b FROM t1")
.createOrReplaceTempView("t2")
val df = sql("SELECT array_except(a, b) FROM t2")
field.dataType match {
case _: StructType =>
// skip due to https://github.com/apache/datafusion-comet/issues/1314
case _ =>
checkSparkAnswer(df)
}
}
}
}
}

}