diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala index 9859a552e..71e451f03 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLRenameITSuite.scala @@ -186,4 +186,67 @@ class FlintSparkPPLRenameITSuite val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + + test("test multiple renamed fields with backticks alias") { + val frame = sql(s""" + | source = $testTable | rename name as `renamed_name`, country as `renamed_country` | fields `renamed_name`, `age`, `renamed_country` + | """.stripMargin) + + val expectedResults: Array[Row] = + Array( + Row("Jake", 70, "USA"), + Row("Hello", 30, "USA"), + Row("John", 25, "Canada"), + Row("Jane", 20, "Canada")) + assertSameRows(expectedResults, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val fieldsProjectList = Seq( + UnresolvedAttribute("renamed_name"), + UnresolvedAttribute("age"), + UnresolvedAttribute("renamed_country")) + val renameProjectList = + Seq( + UnresolvedStar(None), + Alias(UnresolvedAttribute("name"), "renamed_name")(), + Alias(UnresolvedAttribute("country"), "renamed_country")()) + val innerProject = Project(renameProjectList, table) + val planDropColumn = DataFrameDropColumns( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("country")), + innerProject) + val expectedPlan = Project(fieldsProjectList, planDropColumn) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test renamed field with backticks alias used in aggregation") { + val frame = sql(s""" + | source = $testTable | rename age as `user_age` | stats avg(`user_age`) by country + | """.stripMargin) + + val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA")) + assertSameRows(expectedResults, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val renameProjectList = + Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "user_age")()) + val aggregateExpressions = + Seq( + Alias( + UnresolvedFunction( + Seq("AVG"), + Seq(UnresolvedAttribute("user_age")), + isDistinct = false), + "avg(`user_age`)")(), + Alias(UnresolvedAttribute("country"), "country")()) + val innerProject = Project(renameProjectList, table) + val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), innerProject) + val aggregatePlan = Aggregate( + Seq(Alias(UnresolvedAttribute("country"), "country")()), + aggregateExpressions, + planDropColumn) + val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index edf532ac9..494ea4fce 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -349,7 +349,7 @@ joinType ; sideAlias - : (LEFT EQUAL leftAlias = ident)? COMMA? (RIGHT EQUAL rightAlias = ident)? + : (LEFT EQUAL leftAlias = qualifiedName)? COMMA? (RIGHT EQUAL rightAlias = qualifiedName)? ; joinCriteria diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 024c0e7a5..304e7c888 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -162,13 +162,13 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct joinType = Join.JoinType.CROSS; } Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); - Optional leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(ctx.sideAlias().leftAlias.getText()) : Optional.empty(); + Optional leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(internalVisitExpression(ctx.sideAlias().leftAlias).toString()) : Optional.empty(); Optional rightAlias = Optional.empty(); if (ctx.tableOrSubqueryClause().alias != null) { - rightAlias = Optional.of(ctx.tableOrSubqueryClause().alias.getText()); + rightAlias = Optional.of(internalVisitExpression(ctx.tableOrSubqueryClause().alias).toString()); } if (ctx.sideAlias().rightAlias != null) { - rightAlias = Optional.of(ctx.sideAlias().rightAlias.getText()); + rightAlias = Optional.of(internalVisitExpression(ctx.sideAlias().rightAlias).toString()); } UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause()); @@ -248,7 +248,7 @@ public UnresolvedPlan visitRenameCommand(OpenSearchPPLParser.RenameCommandContex .map( ct -> new Alias( - ct.renamedField.getText(), + ((Field) internalVisitExpression(ct.renamedField)).getField().toString(), internalVisitExpression(ct.orignalField))) .collect(Collectors.toList())); } @@ -262,7 +262,7 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext String name = aggCtx.alias == null ? getTextInQuery(aggCtx) - : aggCtx.alias.getText(); + : ((Field) internalVisitExpression(aggCtx.alias)).getField().toString(); Alias alias = new Alias(name, aggExpression); aggListBuilder.add(alias); } @@ -442,7 +442,7 @@ private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParse throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1"); } Field dataField = (Field) expressionBuilder.visitFieldExpression(ctx.field); - String alias = ctx.alias == null?dataField.getField().toString()+"_trendline":ctx.alias.getText(); + String alias = ctx.alias == null? dataField.getField().toString() + "_trendline" : internalVisitExpression(ctx.alias).toString(); String computationType = ctx.trendlineType().getText(); return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, Trendline.TrendlineType.valueOf(computationType.toUpperCase())); } @@ -537,7 +537,7 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubqueryClauseContext ctx) { if (ctx.subSearch() != null) { return ctx.alias != null - ? new SubqueryAlias(ctx.alias.getText(), visitSubSearch(ctx.subSearch())) + ? new SubqueryAlias(internalVisitExpression(ctx.alias).toString(), visitSubSearch(ctx.subSearch())) : visitSubSearch(ctx.subSearch()); } else { return visitTableSourceClause(ctx.tableSourceClause()); @@ -547,7 +547,7 @@ public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubq @Override public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { Relation relation = new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); - return ctx.alias != null ? new SubqueryAlias(ctx.alias.getText(), relation) : relation; + return ctx.alias != null ? new SubqueryAlias(internalVisitExpression(ctx.alias).toString(), relation) : relation; } @Override diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 42cc7ed10..53e942b5d 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -1105,4 +1105,14 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + + test("test average price with backticks alias") { + val expectedPlan = planTransformer.visit( + plan(pplParser, "source = table | stats avg(price) as avg_price"), + new CatalystPlanContext) + val logPlan = planTransformer.visit( + plan(pplParser, "source = table | stats avg(`price`) as `avg_price`"), + new CatalystPlanContext) + comparePlans(expectedPlan, logPlan, false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 1f081bd72..9ec160480 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -394,6 +394,29 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("Search multiple tables - with backticks table alias") { + val expectedPlan = + planTransformer.visit( + plan( + pplParser, + """ + | source=table1, table2, table3 as t + | | where t.name = 'Molly' + |""".stripMargin), + new CatalystPlanContext) + val logPlan = + planTransformer.visit( + plan( + pplParser, + """ + | source=table1, table2, table3 as `t` + | | where `t`.`name` = 'Molly' + |""".stripMargin), + new CatalystPlanContext) + + comparePlans(expectedPlan, logPlan, false) + } + test("test fields + field list") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala index f4ed397e3..badfee7f9 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala @@ -845,4 +845,64 @@ class PPLLogicalPlanJoinTranslatorTestSuite Project(Seq(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name")), joinPlan1) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("test multiple joins with table and subquery backticks alias") { + val originPlan = plan( + pplParser, + s""" + | source = table1 as t1 + | | JOIN left = l right = r ON t1.id = t2.id + | [ + | source = table2 as t2 + | ] + | | JOIN left = l right = r ON t2.id = t3.id + | [ + | source = table3 as t3 + | ] + | | JOIN left = l right = r ON t3.id = t4.id + | [ + | source = table4 as t4 + | ] + | """.stripMargin) + val expectedPlan = planTransformer.visit(originPlan, new CatalystPlanContext) + val logPlan = plan( + pplParser, + s""" + | source = table1 as `t1` + | | JOIN left = `l` right = `r` ON `t1`.`id` = `t2`.`id` + | [ + | source = table2 as `t2` + | ] + | | JOIN left = `l` right = `r` ON `t2`.`id` = `t3`.`id` + | [ + | source = table3 as `t3` + | ] + | | JOIN left = `l` right = `r` ON `t3`.`id` = `t4`.`id` + | [ + | source = table4 as `t4` + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test complex backticks subquery alias") { + val originPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as ttt ] as tt + | | fields t1.name, t2.name + | """.stripMargin) + val expectedPlan = planTransformer.visit(originPlan, new CatalystPlanContext) + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | JOIN left = `t1` right = `t2` ON `t1`.`name` = `t2`.`name` [ source = $testTable2 as `ttt` ] as `tt` + | | fields `t1`.`name`, `t2`.`name` + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala index e02c5b2c4..6c722e22f 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameTranslatorTestSuite.scala @@ -124,4 +124,16 @@ class PPLLogicalPlanRenameTranslatorTestSuite Project(seq(UnresolvedAttribute("eval_rand")), planDropColumn) comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + + test("test rename with backticks alias") { + val expectedPlan = + planTransformer.visit( + plan(pplParser, "source=t | rename a as r_a, b as r_b | fields c"), + new CatalystPlanContext) + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | rename `a` as `r_a`, `b` as `r_b` | fields `c`"), + new CatalystPlanContext) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index ec1775631..cc5d80c62 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -93,6 +93,18 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite comparePlans(logPlan, expectedPlan, checkAnalysis = false) } + test("test trendline with sort and backticks alias") { + val expectedPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort - age sma(3, age) as age_sma"), + new CatalystPlanContext) + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort - `age` sma(3, `age`) as `age_sma`"), + new CatalystPlanContext) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + test("test trendline with multiple trendline sma commands") { val context = new CatalystPlanContext val logPlan =