Skip to content

Commit ae76067

Browse files
committed
UT refactor
Signed-off-by: Lantao Jin <[email protected]>
1 parent 10a4974 commit ae76067

5 files changed

+61
-127
lines changed

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala

+5-30
Original file line numberDiff line numberDiff line change
@@ -1106,38 +1106,13 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
11061106
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
11071107
}
11081108

1109-
test("test average backticks price") {
1110-
val context = new CatalystPlanContext
1111-
val logPlan =
1112-
planTransformer.visit(plan(pplParser, "source = table | stats avg(`price`)"), context)
1113-
val star = Seq(UnresolvedStar(None))
1114-
1115-
val priceField = UnresolvedAttribute("price")
1116-
val tableRelation = UnresolvedRelation(Seq("table"))
1117-
val aggregateExpressions = Seq(
1118-
Alias(
1119-
UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false),
1120-
"avg(`price`)")())
1121-
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
1122-
val expectedPlan = Project(star, aggregatePlan)
1123-
1124-
comparePlans(expectedPlan, logPlan, false)
1125-
}
1126-
11271109
test("test average price with backticks alias") {
1128-
val context = new CatalystPlanContext
1110+
val expectedPlan = planTransformer.visit(
1111+
plan(pplParser, "source = table | stats avg(price) as avg_price"),
1112+
new CatalystPlanContext)
11291113
val logPlan = planTransformer.visit(
1130-
plan(pplParser, "source = table | stats avg(price) as `avg_price`"),
1131-
context)
1132-
val star = Seq(UnresolvedStar(None))
1133-
1134-
val priceField = UnresolvedAttribute("price")
1135-
val tableRelation = UnresolvedRelation(Seq("table"))
1136-
val aggregateExpressions = Seq(
1137-
Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg_price")())
1138-
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
1139-
val expectedPlan = Project(star, aggregatePlan)
1140-
1114+
plan(pplParser, "source = table | stats avg(`price`) as `avg_price`"),
1115+
new CatalystPlanContext)
11411116
comparePlans(expectedPlan, logPlan, false)
11421117
}
11431118
}

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala

+11-26
Original file line numberDiff line numberDiff line change
@@ -395,39 +395,24 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
395395
}
396396

397397
test("Search multiple tables - with backticks table alias") {
398-
val context = new CatalystPlanContext
398+
val expectedPlan =
399+
planTransformer.visit(
400+
plan(
401+
pplParser,
402+
"""
403+
| source=table1, table2, table3 as t
404+
| | where t.name = 'Molly'
405+
|""".stripMargin),
406+
new CatalystPlanContext)
399407
val logPlan =
400408
planTransformer.visit(
401409
plan(
402410
pplParser,
403411
"""
404412
| source=table1, table2, table3 as `t`
405-
| | where t.name = 'Molly'
413+
| | where `t`.`name` = 'Molly'
406414
|""".stripMargin),
407-
context)
408-
409-
val table1 = UnresolvedRelation(Seq("table1"))
410-
val table2 = UnresolvedRelation(Seq("table2"))
411-
val table3 = UnresolvedRelation(Seq("table3"))
412-
val star = UnresolvedStar(None)
413-
val plan1 = Project(
414-
Seq(star),
415-
Filter(
416-
EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")),
417-
SubqueryAlias("t", table1)))
418-
val plan2 = Project(
419-
Seq(star),
420-
Filter(
421-
EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")),
422-
SubqueryAlias("t", table2)))
423-
val plan3 = Project(
424-
Seq(star),
425-
Filter(
426-
EqualTo(UnresolvedAttribute("t.name"), Literal("Molly")),
427-
SubqueryAlias("t", table3)))
428-
429-
val expectedPlan =
430-
Union(Seq(plan1, plan2, plan3), byName = true, allowMissingCol = true)
415+
new CatalystPlanContext)
431416

432417
comparePlans(expectedPlan, logPlan, false)
433418
}

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala

+34-43
Original file line numberDiff line numberDiff line change
@@ -847,71 +847,62 @@ class PPLLogicalPlanJoinTranslatorTestSuite
847847
}
848848

849849
test("test multiple joins with table and subquery backticks alias") {
850-
val context = new CatalystPlanContext
851-
val logPlan = plan(
850+
val originPlan = plan(
852851
pplParser,
853852
s"""
854853
| source = table1 as t1
855-
| | JOIN left = `l` right = `r` ON t1.id = t2.id
854+
| | JOIN left = l right = r ON t1.id = t2.id
855+
| [
856+
| source = table2 as t2
857+
| ]
858+
| | JOIN left = l right = r ON t2.id = t3.id
859+
| [
860+
| source = table3 as t3
861+
| ]
862+
| | JOIN left = l right = r ON t3.id = t4.id
863+
| [
864+
| source = table4 as t4
865+
| ]
866+
| """.stripMargin)
867+
val expectedPlan = planTransformer.visit(originPlan, new CatalystPlanContext)
868+
val logPlan = plan(
869+
pplParser,
870+
s"""
871+
| source = table1 as `t1`
872+
| | JOIN left = `l` right = `r` ON `t1`.`id` = `t2`.`id`
856873
| [
857874
| source = table2 as `t2`
858875
| ]
859-
| | JOIN left = `l` right = `r` ON t2.id = t3.id
876+
| | JOIN left = `l` right = `r` ON `t2`.`id` = `t3`.`id`
860877
| [
861878
| source = table3 as `t3`
862879
| ]
863-
| | JOIN left = `l` right = `r` ON t3.id = t4.id
880+
| | JOIN left = `l` right = `r` ON `t3`.`id` = `t4`.`id`
864881
| [
865882
| source = table4 as `t4`
866883
| ]
867884
| """.stripMargin)
868-
val logicalPlan = planTransformer.visit(logPlan, context)
869-
val table1 = UnresolvedRelation(Seq("table1"))
870-
val table2 = UnresolvedRelation(Seq("table2"))
871-
val table3 = UnresolvedRelation(Seq("table3"))
872-
val table4 = UnresolvedRelation(Seq("table4"))
873-
val joinPlan1 = Join(
874-
SubqueryAlias("l", SubqueryAlias("t1", table1)),
875-
SubqueryAlias("r", SubqueryAlias("t2", table2)),
876-
Inner,
877-
Some(EqualTo(UnresolvedAttribute("t1.id"), UnresolvedAttribute("t2.id"))),
878-
JoinHint.NONE)
879-
val joinPlan2 = Join(
880-
SubqueryAlias("l", joinPlan1),
881-
SubqueryAlias("r", SubqueryAlias("t3", table3)),
882-
Inner,
883-
Some(EqualTo(UnresolvedAttribute("t2.id"), UnresolvedAttribute("t3.id"))),
884-
JoinHint.NONE)
885-
val joinPlan3 = Join(
886-
SubqueryAlias("l", joinPlan2),
887-
SubqueryAlias("r", SubqueryAlias("t4", table4)),
888-
Inner,
889-
Some(EqualTo(UnresolvedAttribute("t3.id"), UnresolvedAttribute("t4.id"))),
890-
JoinHint.NONE)
891-
val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3)
885+
val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext)
892886
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
893887
}
894888

895889
test("test complex backticks subquery alias") {
896-
val context = new CatalystPlanContext
897-
val logPlan = plan(
890+
val originPlan = plan(
898891
pplParser,
899892
s"""
900893
| source = $testTable1
901-
| | JOIN left = `t1` right = `t2` ON t1.name = t2.name [ source = $testTable2 as `ttt` ] as `tt`
894+
| | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as ttt ] as tt
902895
| | fields t1.name, t2.name
903896
| """.stripMargin)
904-
val logicalPlan = planTransformer.visit(logPlan, context)
905-
val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1"))
906-
val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2"))
907-
val joinPlan1 = Join(
908-
SubqueryAlias("t1", table1),
909-
SubqueryAlias("t2", SubqueryAlias("tt", SubqueryAlias("ttt", table2))),
910-
Inner,
911-
Some(EqualTo(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name"))),
912-
JoinHint.NONE)
913-
val expectedPlan =
914-
Project(Seq(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name")), joinPlan1)
897+
val expectedPlan = planTransformer.visit(originPlan, new CatalystPlanContext)
898+
val logPlan = plan(
899+
pplParser,
900+
s"""
901+
| source = $testTable1
902+
| | JOIN left = `t1` right = `t2` ON `t1`.`name` = `t2`.`name` [ source = $testTable2 as `ttt` ] as `tt`
903+
| | fields `t1`.`name`, `t2`.`name`
904+
| """.stripMargin)
905+
val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext)
915906
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
916907
}
917908
}

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala

+5-11
Original file line numberDiff line numberDiff line change
@@ -126,20 +126,14 @@ class PPLLogicalPlanRenameTranslatorTestSuite
126126
}
127127

128128
test("test rename with backticks alias") {
129-
val context = new CatalystPlanContext
129+
val expectedPlan =
130+
planTransformer.visit(
131+
plan(pplParser, "source=t | rename a as r_a, b as r_b | fields c"),
132+
new CatalystPlanContext)
130133
val logPlan =
131134
planTransformer.visit(
132135
plan(pplParser, "source=t | rename `a` as `r_a`, `b` as `r_b` | fields `c`"),
133-
context)
134-
val renameProjectList: Seq[NamedExpression] =
135-
Seq(
136-
UnresolvedStar(None),
137-
Alias(UnresolvedAttribute("a"), "r_a")(),
138-
Alias(UnresolvedAttribute("b"), "r_b")())
139-
val innerProject = Project(renameProjectList, UnresolvedRelation(Seq("t")))
140-
val planDropColumn =
141-
DataFrameDropColumns(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), innerProject)
142-
val expectedPlan = Project(seq(UnresolvedAttribute("c")), planDropColumn)
136+
new CatalystPlanContext)
143137
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
144138
}
145139
}

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala

+6-17
Original file line numberDiff line numberDiff line change
@@ -94,25 +94,14 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite
9494
}
9595

9696
test("test trendline with sort and backticks alias") {
97-
val context = new CatalystPlanContext
97+
val expectedPlan =
98+
planTransformer.visit(
99+
plan(pplParser, "source=relation | trendline sort - age sma(3, age) as age_sma"),
100+
new CatalystPlanContext)
98101
val logPlan =
99102
planTransformer.visit(
100-
plan(pplParser, "source=relation | trendline sort - age sma(3, age) as `age_sma`"),
101-
context)
102-
103-
val table = UnresolvedRelation(Seq("relation"))
104-
val ageField = UnresolvedAttribute("age")
105-
val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table)
106-
val countWindow = new WindowExpression(
107-
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
108-
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
109-
val smaWindow = WindowExpression(
110-
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
111-
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
112-
val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow)
113-
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")())
114-
val expectedPlan =
115-
Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort))
103+
plan(pplParser, "source=relation | trendline sort - `age` sma(3, `age`) as `age_sma`"),
104+
new CatalystPlanContext)
116105
comparePlans(logPlan, expectedPlan, checkAnalysis = false)
117106
}
118107

0 commit comments

Comments
 (0)