Commit 09aba20

Clean expression stack before resolving project list (#1059)

Signed-off-by: Lantao Jin <[email protected]>

1 parent c252899 commit 09aba20
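
Why this fix matters: the PPL-to-Catalyst translator pushes partially built expressions onto a Stack (`namedParseExpressions`) held by `CatalystPlanContext` while it walks the AST. Judging from the change and the new table-alias tests below, expressions left on that stack by an earlier clause (for example a `where` under a table alias) could still be present when a later `fields`, `sort`, or `stats` clause resolved its own expression list, polluting the resulting plan. The commit adds a `resetNamedParseExpressions()` helper and calls it at those clause boundaries. What follows is a toy sketch of the failure mode, not the translator itself; the strings are hypothetical stand-ins for Catalyst expressions:

    import java.util.Stack;
    import static java.util.Collections.emptyList;

    // Toy model of the stale-stack bug this commit guards against.
    public class StaleExpressionStackDemo {
        public static void main(String[] args) {
            Stack<String> namedParseExpressions = new Stack<>();

            // Visiting an earlier clause (e.g. `where l.age >= 30`) can leave
            // an expression behind on some code paths.
            namedParseExpressions.push("l.age >= 30");

            // Without a reset, resolving `fields l.name, l.age` drains the whole
            // stack, so the stale filter expression leaks into the project list.
            namedParseExpressions.push("l.name");
            namedParseExpressions.push("l.age");
            System.out.println("without reset: " + namedParseExpressions);

            // The fix: clear the stack first -- which is exactly what
            // resetNamedParseExpressions() does via retainAll(emptyList()).
            namedParseExpressions.retainAll(emptyList());
            namedParseExpressions.push("l.name");
            namedParseExpressions.push("l.age");
            System.out.println("with reset: " + namedParseExpressions);
        }
    }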

File tree: 8 files changed (+185 -16)

integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala (+19 -1)

@@ -8,7 +8,7 @@ package org.opensearch.flint.spark.ppl
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, EqualTo, IsNotNull, Literal, Not, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, EqualTo, GreaterThanOrEqual, IsNotNull, Literal, Not, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.ExplainMode
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExplainCommand}
@@ -671,4 +671,22 @@ class FlintSparkPPLBasicITSuite

     comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
   }
+
+  test("test fields with alias") {
+    val frame = sql(s"""
+         | source = $testTable as l | where l.age >= 30 | fields l.name, l.age
+         | """.stripMargin)
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val expectedPlan = Project(
+      Seq(UnresolvedAttribute("l.name"), UnresolvedAttribute("l.age")),
+      Filter(
+        GreaterThanOrEqual(UnresolvedAttribute("l.age"), Literal(30)),
+        SubqueryAlias(
+          "l",
+          UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))))
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
+    assertSameRows(Seq(Row("Jake", 70), Row("Hello", 30)), frame)
+  }
 }

integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala (+41)

@@ -475,4 +475,45 @@ class FlintSparkPPLInSubqueryITSuite
     assert(ex.getMessage.contains(
       "The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery"))
   }
+
+  test("test in subquery with table alias") {
+    val frame = sql(s"""
+         | source = $outerTable as o
+         | | where id in [
+         | source = $innerTable as i
+         | | where i.department = 'DATA'
+         | | fields i.uid
+         | ]
+         | | sort - o.salary
+         | | fields o.id, o.name, o.salary
+         | """.stripMargin)
+    val expectedResults: Array[Row] = Array(Row(1002, "John", 120000), Row(1005, "Jane", 90000))
+    assertSameRows(expectedResults, frame)
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val outer =
+      SubqueryAlias("o", UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")))
+    val inner =
+      SubqueryAlias("i", UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")))
+    val inSubquery =
+      Filter(
+        InSubquery(
+          Seq(UnresolvedAttribute("id")),
+          ListQuery(
+            Project(
+              Seq(UnresolvedAttribute("i.uid")),
+              Filter(EqualTo(UnresolvedAttribute("i.department"), Literal("DATA")), inner)))),
+        outer)
+    val sortedPlan: LogicalPlan =
+      Sort(Seq(SortOrder(UnresolvedAttribute("o.salary"), Descending)), global = true, inSubquery)
+    val expectedPlan =
+      Project(
+        Seq(
+          UnresolvedAttribute("o.id"),
+          UnresolvedAttribute("o.name"),
+          UnresolvedAttribute("o.salary")),
+        sortedPlan)
+
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
 }

integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala (+99)

@@ -481,4 +481,103 @@ class FlintSparkPPLScalarSubqueryITSuite
     implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0))
     assert(results.sorted.sameElements(expectedResults.sorted))
   }
+
+  test("test nested scalar subquery with table alias") {
+    val frame = sql(s"""
+         | source = $outerTable as o
+         | | where id = [
+         | source = $innerTable as i
+         | | where uid = [
+         | source = $nestedInnerTable as n
+         | | stats min(n.salary)
+         | ] + 1000
+         | | sort i.department
+         | | stats max(i.uid)
+         | ]
+         | | fields o.id, o.name
+         | """.stripMargin)
+    val expectedResults: Array[Row] = Array(Row(1000, "Jake"))
+    assertSameRows(expectedResults, frame)
+  }
+
+  test("test correlated scalar subquery with table alias") {
+    val frame = sql(s"""
+         | source = $outerTable as o
+         | | where id = [
+         | source = $innerTable as i | where o.id = i.uid | stats max(i.uid)
+         | ]
+         | | fields o.id, o.name
+         | """.stripMargin)
+    val expectedResults: Array[Row] = Array(
+      Row(1000, "Jake"),
+      Row(1002, "John"),
+      Row(1003, "David"),
+      Row(1005, "Jane"),
+      Row(1006, "Tommy"))
+    assertSameRows(expectedResults, frame)
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+
+    val outer =
+      SubqueryAlias("o", UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")))
+    val inner =
+      SubqueryAlias("i", UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")))
+    val aggregateExpressions = Seq(
+      Alias(
+        UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("i.uid")), isDistinct = false),
+        "max(i.uid)")())
+    val innerFilter =
+      Filter(EqualTo(UnresolvedAttribute("o.id"), UnresolvedAttribute("i.uid")), inner)
+    val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter)
+    val scalarSubqueryExpr = ScalarSubquery(aggregatePlan)
+    val outerFilter = Filter(EqualTo(UnresolvedAttribute("id"), scalarSubqueryExpr), outer)
+    val expectedPlan =
+      Project(Seq(UnresolvedAttribute("o.id"), UnresolvedAttribute("o.name")), outerFilter)
+
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
+  }
+
+  test("test uncorrelated scalar subquery with table alias") {
+    val frame = sql(s"""
+         | source = $outerTable as o
+         | | eval max_uid = [
+         | source = $innerTable as i | where i.department = 'DATA' | stats max(i.uid)
+         | ]
+         | | fields o.id, o.name, max_uid
+         | """.stripMargin)
+    val expectedResults: Array[Row] = Array(
+      Row(1000, "Jake", 1005),
+      Row(1001, "Hello", 1005),
+      Row(1002, "John", 1005),
+      Row(1003, "David", 1005),
+      Row(1004, "David", 1005),
+      Row(1005, "Jane", 1005),
+      Row(1006, "Tommy", 1005))
+    assertSameRows(expectedResults, frame)
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+
+    val outer =
+      SubqueryAlias("o", UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")))
+    val inner =
+      SubqueryAlias("i", UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")))
+    val aggregateExpressions = Seq(
+      Alias(
+        UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("i.uid")), isDistinct = false),
+        "max(i.uid)")())
+    val innerFilter =
+      Filter(EqualTo(UnresolvedAttribute("i.department"), Literal("DATA")), inner)
+    val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter)
+    val scalarSubqueryExpr = Alias(ScalarSubquery(aggregatePlan), "max_uid")()
+    val outerFilter = Project(Seq(UnresolvedStar(None), scalarSubqueryExpr), outer)
+    val expectedPlan =
+      Project(
+        Seq(
+          UnresolvedAttribute("o.id"),
+          UnresolvedAttribute("o.name"),
+          UnresolvedAttribute("max_uid")),
+        outerFilter)
+
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
+  }
 }

ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystExpressionVisitor.java (+2 -2)

@@ -349,7 +349,7 @@ public Expression visitCase(Case node, CatalystPlanContext context) {
                 )
             );
         }
-        context.retainAllNamedParseExpressions(e -> e);
+        context.resetNamedParseExpressions();
     }
     context.setNamedParseExpressions(initialNameExpressions);
     return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue)));
@@ -421,7 +421,7 @@ public Expression visitBetween(Between node, CatalystPlanContext context) {
     Expression value = analyze(node.getValue(), context);
     Expression lower = analyze(node.getLowerBound(), context);
     Expression upper = analyze(node.getUpperBound(), context);
-    context.retainAllNamedParseExpressions(p -> p);
+    context.resetNamedParseExpressions();
     return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper)));
 }

ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java (+9 -6)

@@ -8,15 +8,11 @@
 import lombok.Getter;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
-import org.apache.spark.sql.catalyst.expressions.AttributeReference;
 import org.apache.spark.sql.catalyst.expressions.Expression;
-import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias;
 import org.apache.spark.sql.catalyst.plans.logical.Union;
-import org.apache.spark.sql.types.Metadata;
 import org.opensearch.sql.ast.expression.UnresolvedExpression;
-import org.opensearch.sql.data.type.ExprType;
 import scala.collection.Iterator;
 import scala.collection.Seq;
@@ -27,7 +23,6 @@
 import java.util.Stack;
 import java.util.function.BiFunction;
 import java.util.function.Function;
-import java.util.function.UnaryOperator;
 import java.util.stream.Collectors;

 import static java.util.Collections.emptyList;
@@ -218,6 +213,14 @@ public <T> Seq<T> retainAllPlans(Function<LogicalPlan, T> transformFunction) {
         return plans;
     }

+    /**
+     * Reset all expressions in stack,
+     * generally use it after calling visitFirstChild() in visit methods.
+     */
+    public void resetNamedParseExpressions() {
+        getNamedParseExpressions().retainAll(emptyList());
+    }
+
     /**
      * retain all expressions and clear expression stack
      *
@@ -226,7 +229,7 @@ public <T> Seq<T> retainAllPlans(Function<LogicalPlan, T> transformFunction) {
     public <T> Seq<T> retainAllNamedParseExpressions(Function<Expression, T> transformFunction) {
         Seq<T> aggregateExpressions = seq(getNamedParseExpressions().stream()
                 .map(transformFunction).collect(Collectors.toList()));
-        getNamedParseExpressions().retainAll(emptyList());
+        resetNamedParseExpressions();
         return aggregateExpressions;
     }
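A note on the idiom: java.util.Stack extends Vector, so `retainAll(emptyList())` removes every element, since nothing is contained in an empty collection. The new helper simply gives that idiom a name, and `retainAllNamedParseExpressions()` now delegates to it after draining the stack. A minimal standalone check (plain JDK, hypothetical values):

    import java.util.Stack;
    import static java.util.Collections.emptyList;

    public class RetainAllClearsDemo {
        public static void main(String[] args) {
            Stack<String> stack = new Stack<>();
            stack.push("a");
            stack.push("b");

            // Nothing is contained in the empty list, so nothing is retained.
            stack.retainAll(emptyList());
            System.out.println(stack.isEmpty()); // prints: true
        }
    }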

ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java (+12 -4)

@@ -89,12 +89,14 @@
 import scala.collection.Seq;

 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;

 import static java.util.Collections.emptyList;
 import static java.util.List.of;
@@ -196,7 +198,7 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
 public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
     visitFirstChild(node, context);
     return context.apply( searchSide -> {
-        context.retainAllNamedParseExpressions(p -> p);
+        context.resetNamedParseExpressions();
         context.retainAllPlans(p -> p);
         LogicalPlan target;
         LogicalPlan lookupTable = node.getLookupRelation().accept(this, context);
@@ -257,7 +259,7 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {

         LogicalPlan outputWithDropped = DataFrameDropColumns$.MODULE$.apply(seq(toDrop), target);

-        context.retainAllNamedParseExpressions(p -> p);
+        context.resetNamedParseExpressions();
         context.retainAllPlans(p -> p);
         return outputWithDropped;
     });
@@ -304,7 +306,7 @@ public LogicalPlan visitAppendCol(AppendCol node, CatalystPlanContext context) {
     var subSearchWithRowNumber = getRowNumStarProjection(context, subSearch, TABLE_RHS);

     context.withSubqueryAlias(subSearchWithRowNumber);
-    context.retainAllNamedParseExpressions(p -> p);
+    context.resetNamedParseExpressions();
     context.retainAllPlans(p -> p);

     // Join both Main and Sub search with _ROW_NUMBER_ column
@@ -351,7 +353,7 @@ public LogicalPlan visitJoin(Join node, CatalystPlanContext context) {
         LogicalPlan right = node.getRight().accept(this, context);
         Optional<Expression> joinCondition = node.getJoinCondition()
                 .map(c -> expressionAnalyzer.analyzeJoinCondition(c, context));
-        context.retainAllNamedParseExpressions(p -> p);
+        context.resetNamedParseExpressions();
         context.retainAllPlans(p -> p);
         return join(left, right, node.getJoinType(), joinCondition, node.getJoinHint());
     });
@@ -371,6 +373,8 @@ public LogicalPlan visitSubqueryAlias(SubqueryAlias node, CatalystPlanContext co
 @Override
 public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) {
     visitFirstChild(node, context);
+    // clean before to go
+    context.resetNamedParseExpressions();
     List<Expression> aggsExpList = visitExpressionList(node.getAggExprList(), context);
     List<Expression> groupExpList = visitExpressionList(node.getGroupExprList(), context);
     if (!groupExpList.isEmpty()) {
@@ -460,6 +464,9 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
         context.withProjectedFields(node.getProjectList());
     }
     LogicalPlan child = visitFirstChild(node, context);
+
+    // reset expression stack before resolving project
+    context.resetNamedParseExpressions();
     visitExpressionList(node.getProjectList(), context);

     // Create a projection list from the existing expressions
@@ -481,6 +488,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
 @Override
 public LogicalPlan visitSort(Sort node, CatalystPlanContext context) {
     visitFirstChild(node, context);
+    context.resetNamedParseExpressions();
     visitFieldList(node.getSortList(), context);
     Seq<SortOrder> sortElements = context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp));
     return context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p));
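
The recurring shape in `visitAggregation`, `visitProject`, and `visitSort` is now: visit the child first, clear whatever the child left on the expression stack, then resolve this node's own expression list so the subsequent drain returns exactly those expressions. A simplified, hypothetical model of that ordering (strings stand in for Catalyst expressions; the real methods build `LogicalPlan` nodes):

    import java.util.List;
    import java.util.Stack;
    import static java.util.Collections.emptyList;

    // Hypothetical, stripped-down model of the visit ordering fixed here.
    public class VisitOrderSketch {
        private final Stack<String> namedParseExpressions = new Stack<>();

        String visitProject(List<String> projectList, Runnable visitFirstChild) {
            visitFirstChild.run();                            // may leave stale expressions
            namedParseExpressions.retainAll(emptyList());     // reset before resolving the list
            projectList.forEach(namedParseExpressions::push); // push only projection expressions
            return "Project" + namedParseExpressions;         // drain exactly what was pushed
        }

        public static void main(String[] args) {
            VisitOrderSketch visitor = new VisitOrderSketch();
            String plan = visitor.visitProject(
                List.of("o.id", "o.name"),
                () -> visitor.namedParseExpressions.push("o.salary > 0"));
            System.out.println(plan); // Project[o.id, o.name] -- no stale entry
        }
    }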

ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/GeoIpCatalystLogicalPlanTranslator.java (+1 -1)

@@ -116,7 +116,7 @@ static private LogicalPlan applyJoin(Expression ipAddress, CatalystPlanContext c
             UnresolvedAttribute$.MODULE$.apply(seq(GEOIP_TABLE_ALIAS,GEOIP_IPV4_COLUMN_NAME))
         )
     ));
-    context.retainAllNamedParseExpressions(p -> p);
+    context.resetNamedParseExpressions();
     context.retainAllPlans(p -> p);
     return join(leftAlias,
             rightAlias,

ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java (+2 -2)

@@ -75,7 +75,7 @@ static Expression buildLookupMappingCondition(
         Expression equalTo = EqualTo$.MODULE$.apply(lookupNamedExpression, sourceNamedExpression);
         equiConditions.add(equalTo);
     }
-    context.retainAllNamedParseExpressions(e -> e);
+    context.resetNamedParseExpressions();
     return equiConditions.stream().reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null);
 }

@@ -114,7 +114,7 @@ static List<NamedExpression> buildOutputProjectList(
             seq(new java.util.ArrayList<String>()));
         outputProjectList.add(output);
     }
-    context.retainAllNamedParseExpressions(p -> p);
+    context.resetNamedParseExpressions();
     return outputProjectList;
 }
