Skip to content

Commit b9a0dc5

Browse files
authored
Alias could be wrapped with backticks in RENAME and other commands (#1066)
1 parent 7c87378 commit b9a0dc5

File tree

8 files changed

+189
-9
lines changed

8 files changed

+189
-9
lines changed

integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala

+63
Original file line numberDiff line numberDiff line change
@@ -186,4 +186,67 @@ class FlintSparkPPLRenameITSuite
186186
val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
187187
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
188188
}
189+
190+
test("test multiple renamed fields with backticks alias") {
191+
val frame = sql(s"""
192+
| source = $testTable | rename name as `renamed_name`, country as `renamed_country` | fields `renamed_name`, `age`, `renamed_country`
193+
| """.stripMargin)
194+
195+
val expectedResults: Array[Row] =
196+
Array(
197+
Row("Jake", 70, "USA"),
198+
Row("Hello", 30, "USA"),
199+
Row("John", 25, "Canada"),
200+
Row("Jane", 20, "Canada"))
201+
assertSameRows(expectedResults, frame)
202+
203+
val logicalPlan: LogicalPlan = frame.queryExecution.logical
204+
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
205+
val fieldsProjectList = Seq(
206+
UnresolvedAttribute("renamed_name"),
207+
UnresolvedAttribute("age"),
208+
UnresolvedAttribute("renamed_country"))
209+
val renameProjectList =
210+
Seq(
211+
UnresolvedStar(None),
212+
Alias(UnresolvedAttribute("name"), "renamed_name")(),
213+
Alias(UnresolvedAttribute("country"), "renamed_country")())
214+
val innerProject = Project(renameProjectList, table)
215+
val planDropColumn = DataFrameDropColumns(
216+
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("country")),
217+
innerProject)
218+
val expectedPlan = Project(fieldsProjectList, planDropColumn)
219+
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
220+
}
221+
222+
test("test renamed field with backticks alias used in aggregation") {
223+
val frame = sql(s"""
224+
| source = $testTable | rename age as `user_age` | stats avg(`user_age`) by country
225+
| """.stripMargin)
226+
227+
val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA"))
228+
assertSameRows(expectedResults, frame)
229+
230+
val logicalPlan: LogicalPlan = frame.queryExecution.logical
231+
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
232+
val renameProjectList =
233+
Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "user_age")())
234+
val aggregateExpressions =
235+
Seq(
236+
Alias(
237+
UnresolvedFunction(
238+
Seq("AVG"),
239+
Seq(UnresolvedAttribute("user_age")),
240+
isDistinct = false),
241+
"avg(`user_age`)")(),
242+
Alias(UnresolvedAttribute("country"), "country")())
243+
val innerProject = Project(renameProjectList, table)
244+
val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), innerProject)
245+
val aggregatePlan = Aggregate(
246+
Seq(Alias(UnresolvedAttribute("country"), "country")()),
247+
aggregateExpressions,
248+
planDropColumn)
249+
val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
250+
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
251+
}
189252
}

ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4

+1-1
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ joinType
349349
;
350350

351351
sideAlias
352-
: (LEFT EQUAL leftAlias = ident)? COMMA? (RIGHT EQUAL rightAlias = ident)?
352+
: (LEFT EQUAL leftAlias = qualifiedName)? COMMA? (RIGHT EQUAL rightAlias = qualifiedName)?
353353
;
354354

355355
joinCriteria

ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java

+8-8
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,13 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct
162162
joinType = Join.JoinType.CROSS;
163163
}
164164
Join.JoinHint joinHint = getJoinHint(ctx.joinHintList());
165-
Optional<String> leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(ctx.sideAlias().leftAlias.getText()) : Optional.empty();
165+
Optional<String> leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(internalVisitExpression(ctx.sideAlias().leftAlias).toString()) : Optional.empty();
166166
Optional<String> rightAlias = Optional.empty();
167167
if (ctx.tableOrSubqueryClause().alias != null) {
168-
rightAlias = Optional.of(ctx.tableOrSubqueryClause().alias.getText());
168+
rightAlias = Optional.of(internalVisitExpression(ctx.tableOrSubqueryClause().alias).toString());
169169
}
170170
if (ctx.sideAlias().rightAlias != null) {
171-
rightAlias = Optional.of(ctx.sideAlias().rightAlias.getText());
171+
rightAlias = Optional.of(internalVisitExpression(ctx.sideAlias().rightAlias).toString());
172172
}
173173

174174
UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause());
@@ -248,7 +248,7 @@ public UnresolvedPlan visitRenameCommand(OpenSearchPPLParser.RenameCommandContex
248248
.map(
249249
ct ->
250250
new Alias(
251-
ct.renamedField.getText(),
251+
((Field) internalVisitExpression(ct.renamedField)).getField().toString(),
252252
internalVisitExpression(ct.orignalField)))
253253
.collect(Collectors.toList()));
254254
}
@@ -262,7 +262,7 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext
262262
String name =
263263
aggCtx.alias == null
264264
? getTextInQuery(aggCtx)
265-
: aggCtx.alias.getText();
265+
: ((Field) internalVisitExpression(aggCtx.alias)).getField().toString();
266266
Alias alias = new Alias(name, aggExpression);
267267
aggListBuilder.add(alias);
268268
}
@@ -442,7 +442,7 @@ private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParse
442442
throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1");
443443
}
444444
Field dataField = (Field) expressionBuilder.visitFieldExpression(ctx.field);
445-
String alias = ctx.alias == null?dataField.getField().toString()+"_trendline":ctx.alias.getText();
445+
String alias = ctx.alias == null? dataField.getField().toString() + "_trendline" : internalVisitExpression(ctx.alias).toString();
446446
String computationType = ctx.trendlineType().getText();
447447
return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, Trendline.TrendlineType.valueOf(computationType.toUpperCase()));
448448
}
@@ -537,7 +537,7 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct
537537
public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubqueryClauseContext ctx) {
538538
if (ctx.subSearch() != null) {
539539
return ctx.alias != null
540-
? new SubqueryAlias(ctx.alias.getText(), visitSubSearch(ctx.subSearch()))
540+
? new SubqueryAlias(internalVisitExpression(ctx.alias).toString(), visitSubSearch(ctx.subSearch()))
541541
: visitSubSearch(ctx.subSearch());
542542
} else {
543543
return visitTableSourceClause(ctx.tableSourceClause());
@@ -547,7 +547,7 @@ public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubq
547547
@Override
548548
public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) {
549549
Relation relation = new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()));
550-
return ctx.alias != null ? new SubqueryAlias(ctx.alias.getText(), relation) : relation;
550+
return ctx.alias != null ? new SubqueryAlias(internalVisitExpression(ctx.alias).toString(), relation) : relation;
551551
}
552552

553553
@Override

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala

+10
Original file line numberDiff line numberDiff line change
@@ -1105,4 +1105,14 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
11051105
val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate)
11061106
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
11071107
}
1108+
1109+
test("test average price with backticks alias") {
1110+
val expectedPlan = planTransformer.visit(
1111+
plan(pplParser, "source = table | stats avg(price) as avg_price"),
1112+
new CatalystPlanContext)
1113+
val logPlan = planTransformer.visit(
1114+
plan(pplParser, "source = table | stats avg(`price`) as `avg_price`"),
1115+
new CatalystPlanContext)
1116+
comparePlans(expectedPlan, logPlan, false)
1117+
}
11081118
}

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala

+23
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,29 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
394394
comparePlans(expectedPlan, logPlan, false)
395395
}
396396

397+
test("Search multiple tables - with backticks table alias") {
398+
val expectedPlan =
399+
planTransformer.visit(
400+
plan(
401+
pplParser,
402+
"""
403+
| source=table1, table2, table3 as t
404+
| | where t.name = 'Molly'
405+
|""".stripMargin),
406+
new CatalystPlanContext)
407+
val logPlan =
408+
planTransformer.visit(
409+
plan(
410+
pplParser,
411+
"""
412+
| source=table1, table2, table3 as `t`
413+
| | where `t`.`name` = 'Molly'
414+
|""".stripMargin),
415+
new CatalystPlanContext)
416+
417+
comparePlans(expectedPlan, logPlan, false)
418+
}
419+
397420
test("test fields + field list") {
398421
val context = new CatalystPlanContext
399422
val logPlan = planTransformer.visit(

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala

+60
Original file line numberDiff line numberDiff line change
@@ -845,4 +845,64 @@ class PPLLogicalPlanJoinTranslatorTestSuite
845845
Project(Seq(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name")), joinPlan1)
846846
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
847847
}
848+
849+
test("test multiple joins with table and subquery backticks alias") {
850+
val originPlan = plan(
851+
pplParser,
852+
s"""
853+
| source = table1 as t1
854+
| | JOIN left = l right = r ON t1.id = t2.id
855+
| [
856+
| source = table2 as t2
857+
| ]
858+
| | JOIN left = l right = r ON t2.id = t3.id
859+
| [
860+
| source = table3 as t3
861+
| ]
862+
| | JOIN left = l right = r ON t3.id = t4.id
863+
| [
864+
| source = table4 as t4
865+
| ]
866+
| """.stripMargin)
867+
val expectedPlan = planTransformer.visit(originPlan, new CatalystPlanContext)
868+
val logPlan = plan(
869+
pplParser,
870+
s"""
871+
| source = table1 as `t1`
872+
| | JOIN left = `l` right = `r` ON `t1`.`id` = `t2`.`id`
873+
| [
874+
| source = table2 as `t2`
875+
| ]
876+
| | JOIN left = `l` right = `r` ON `t2`.`id` = `t3`.`id`
877+
| [
878+
| source = table3 as `t3`
879+
| ]
880+
| | JOIN left = `l` right = `r` ON `t3`.`id` = `t4`.`id`
881+
| [
882+
| source = table4 as `t4`
883+
| ]
884+
| """.stripMargin)
885+
val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext)
886+
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
887+
}
888+
889+
test("test complex backticks subquery alias") {
890+
val originPlan = plan(
891+
pplParser,
892+
s"""
893+
| source = $testTable1
894+
| | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as ttt ] as tt
895+
| | fields t1.name, t2.name
896+
| """.stripMargin)
897+
val expectedPlan = planTransformer.visit(originPlan, new CatalystPlanContext)
898+
val logPlan = plan(
899+
pplParser,
900+
s"""
901+
| source = $testTable1
902+
| | JOIN left = `t1` right = `t2` ON `t1`.`name` = `t2`.`name` [ source = $testTable2 as `ttt` ] as `tt`
903+
| | fields `t1`.`name`, `t2`.`name`
904+
| """.stripMargin)
905+
val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext)
906+
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
907+
}
848908
}

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala

+12
Original file line numberDiff line numberDiff line change
@@ -124,4 +124,16 @@ class PPLLogicalPlanRenameTranslatorTestSuite
124124
Project(seq(UnresolvedAttribute("eval_rand")), planDropColumn)
125125
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
126126
}
127+
128+
test("test rename with backticks alias") {
129+
val expectedPlan =
130+
planTransformer.visit(
131+
plan(pplParser, "source=t | rename a as r_a, b as r_b | fields c"),
132+
new CatalystPlanContext)
133+
val logPlan =
134+
planTransformer.visit(
135+
plan(pplParser, "source=t | rename `a` as `r_a`, `b` as `r_b` | fields `c`"),
136+
new CatalystPlanContext)
137+
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
138+
}
127139
}

ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala

+12
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite
9393
comparePlans(logPlan, expectedPlan, checkAnalysis = false)
9494
}
9595

96+
test("test trendline with sort and backticks alias") {
97+
val expectedPlan =
98+
planTransformer.visit(
99+
plan(pplParser, "source=relation | trendline sort - age sma(3, age) as age_sma"),
100+
new CatalystPlanContext)
101+
val logPlan =
102+
planTransformer.visit(
103+
plan(pplParser, "source=relation | trendline sort - `age` sma(3, `age`) as `age_sma`"),
104+
new CatalystPlanContext)
105+
comparePlans(logPlan, expectedPlan, checkAnalysis = false)
106+
}
107+
96108
test("test trendline with multiple trendline sma commands") {
97109
val context = new CatalystPlanContext
98110
val logPlan =

0 commit comments

Comments
 (0)