Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: support array_compact function #1321

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
19 changes: 19 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,25 @@ impl PhysicalPlanner {
));
Ok(array_join_expr)
}
ExprStruct::ArrayCompact(expr) => {
// Lower Spark's `array_compact(arr)` (drop NULL elements) onto DataFusion's
// `array_remove_all`, using a typed NULL literal as the element to remove.
let src_array_expr =
self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
// Element type of the array, carried in the protobuf message so the NULL
// literal below can be cast to a ScalarValue of the matching type.
let datatype = to_arrow_datatype(expr.item_datatype.as_ref().unwrap());

let null_literal_expr: Arc<dyn PhysicalExpr> =
Arc::new(Literal::new(ScalarValue::Null.cast_to(&datatype)?));
let args = vec![Arc::clone(&src_array_expr), null_literal_expr];
// Removing elements does not change the array's data type.
let return_type = src_array_expr.data_type(&input_schema)?;

// NOTE(review): correctness relies on `array_remove_all` treating the NULL
// needle as matching NULL elements — confirm DataFusion's null-equality
// semantics for this UDF; many equality-based removals skip NULLs.
let array_compact_expr = Arc::new(ScalarFunctionExpr::new(
"array_compact",
array_remove_all_udf(),
args,
return_type,
));

Ok(array_compact_expr)
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
Expand Down
6 changes: 6 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ message Expr {
BinaryExpr array_remove = 61;
BinaryExpr array_intersect = 62;
ArrayJoin array_join = 63;
ArrayCompact array_compact = 64;
}
}

Expand Down Expand Up @@ -422,6 +423,11 @@ message ArrayJoin {
Expr null_replacement_expr = 3;
}

// Spark's `array_compact`: removes NULL elements from an array.
message ArrayCompact {
// The array-typed input expression.
Expr array_expr = 1;
// Element type of the array; the native side needs it to build a typed
// NULL literal for the removal.
DataType item_datatype = 2;
}

message DataType {
enum DataTypeId {
BOOL = 0;
Expand Down
16 changes: 16 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2428,6 +2428,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
None
}
// Spark 3.4+ implements `array_compact(a)` as a RuntimeReplaceable that is
// rewritten to an ArrayFilter dropping nulls. Detect that specific ArrayFilter
// by comparing its SQL against the replacement Spark would generate for
// ArrayCompact on the same child.
// NOTE(review): building `ArrayCompact(child).replacement.sql` in the guard on
// every ArrayFilter is relatively expensive — consider matching the filter
// lambda structurally instead.
case expr @ ArrayFilter(child, _) if ArrayCompact(child).replacement.sql == expr.sql =>
  // Element type of the input array, needed natively to build a typed NULL
  // literal to remove.
  val elementType = serializeDataType(child.dataType.asInstanceOf[ArrayType].elementType)
  val srcExprProto = exprToProto(child, inputs, binding)
  if (elementType.isDefined && srcExprProto.isDefined) {
    val arrayCompactBuilder = ExprOuterClass.ArrayCompact
      .newBuilder()
      .setArrayExpr(srcExprProto.get)
      .setItemDatatype(elementType.get)
    Some(
      ExprOuterClass.Expr
        .newBuilder()
        .setArrayCompact(arrayCompactBuilder)
        .build())
  } else {
    // Record why Comet fell back to Spark, consistent with the other cases
    // (previously this returned None silently, hiding the reason from users).
    withInfo(expr, "unsupported arguments for ArrayCompact", child)
    None
  }
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
Expand Down
16 changes: 16 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2701,4 +2701,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("array_compact") {
  // array_compact exists only in Spark 3.4+, where it is rewritten into an
  // ArrayFilter that Comet recognizes and serializes natively.
  assume(isSpark34Plus)
  for (useDictionary <- Seq(true, false)) {
    withTempDir { dir =>
      val path = new Path(dir.toURI.toString, "test.parquet")
      makeParquetFileAllTypes(path, dictionaryEnabled = useDictionary, n = 10000)
      spark.read.parquet(path.toString).createOrReplaceTempView("t1")

      // Rows whose element is null: compact should drop it from the array.
      checkSparkAnswerAndOperator(
        sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NULL"))
      // Rows with a non-null element: compact should leave the array intact.
      checkSparkAnswerAndOperator(
        sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NOT NULL"))
    }
  }
}
}