Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alias could be wrapped with backticks in RENAME and other commands #1066

Merged
merged 2 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,67 @@ class FlintSparkPPLRenameITSuite
val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test multiple renamed fields with backticks alias") {
val frame = sql(s"""
| source = $testTable | rename name as `renamed_name`, country as `renamed_country` | fields `renamed_name`, `age`, `renamed_country`
| """.stripMargin)

val expectedResults: Array[Row] =
Array(
Row("Jake", 70, "USA"),
Row("Hello", 30, "USA"),
Row("John", 25, "Canada"),
Row("Jane", 20, "Canada"))
assertSameRows(expectedResults, frame)

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val fieldsProjectList = Seq(
UnresolvedAttribute("renamed_name"),
UnresolvedAttribute("age"),
UnresolvedAttribute("renamed_country"))
val renameProjectList =
Seq(
UnresolvedStar(None),
Alias(UnresolvedAttribute("name"), "renamed_name")(),
Alias(UnresolvedAttribute("country"), "renamed_country")())
val innerProject = Project(renameProjectList, table)
val planDropColumn = DataFrameDropColumns(
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("country")),
innerProject)
val expectedPlan = Project(fieldsProjectList, planDropColumn)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test renamed field with backticks alias used in aggregation") {
val frame = sql(s"""
| source = $testTable | rename age as `user_age` | stats avg(`user_age`) by country
| """.stripMargin)

val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA"))
assertSameRows(expectedResults, frame)

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val renameProjectList =
Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "user_age")())
val aggregateExpressions =
Seq(
Alias(
UnresolvedFunction(
Seq("AVG"),
Seq(UnresolvedAttribute("user_age")),
isDistinct = false),
"avg(`user_age`)")(),
Alias(UnresolvedAttribute("country"), "country")())
val innerProject = Project(renameProjectList, table)
val planDropColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("age")), innerProject)
val aggregatePlan = Aggregate(
Seq(Alias(UnresolvedAttribute("country"), "country")()),
aggregateExpressions,
planDropColumn)
val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ joinType
;

sideAlias
: (LEFT EQUAL leftAlias = ident)? COMMA? (RIGHT EQUAL rightAlias = ident)?
: (LEFT EQUAL leftAlias = qualifiedName)? COMMA? (RIGHT EQUAL rightAlias = qualifiedName)?
;

joinCriteria
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,13 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct
joinType = Join.JoinType.CROSS;
}
Join.JoinHint joinHint = getJoinHint(ctx.joinHintList());
Optional<String> leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(ctx.sideAlias().leftAlias.getText()) : Optional.empty();
Optional<String> leftAlias = ctx.sideAlias().leftAlias != null ? Optional.of(internalVisitExpression(ctx.sideAlias().leftAlias).toString()) : Optional.empty();
Optional<String> rightAlias = Optional.empty();
if (ctx.tableOrSubqueryClause().alias != null) {
rightAlias = Optional.of(ctx.tableOrSubqueryClause().alias.getText());
rightAlias = Optional.of(internalVisitExpression(ctx.tableOrSubqueryClause().alias).toString());
}
if (ctx.sideAlias().rightAlias != null) {
rightAlias = Optional.of(ctx.sideAlias().rightAlias.getText());
rightAlias = Optional.of(internalVisitExpression(ctx.sideAlias().rightAlias).toString());
}

UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause());
Expand Down Expand Up @@ -248,7 +248,7 @@ public UnresolvedPlan visitRenameCommand(OpenSearchPPLParser.RenameCommandContex
.map(
ct ->
new Alias(
ct.renamedField.getText(),
((Field) internalVisitExpression(ct.renamedField)).getField().toString(),
internalVisitExpression(ct.orignalField)))
.collect(Collectors.toList()));
}
Expand All @@ -262,7 +262,7 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext
String name =
aggCtx.alias == null
? getTextInQuery(aggCtx)
: aggCtx.alias.getText();
: ((Field) internalVisitExpression(aggCtx.alias)).getField().toString();
Alias alias = new Alias(name, aggExpression);
aggListBuilder.add(alias);
}
Expand Down Expand Up @@ -442,7 +442,7 @@ private Trendline.TrendlineComputation toTrendlineComputation(OpenSearchPPLParse
throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1");
}
Field dataField = (Field) expressionBuilder.visitFieldExpression(ctx.field);
String alias = ctx.alias == null?dataField.getField().toString()+"_trendline":ctx.alias.getText();
String alias = ctx.alias == null? dataField.getField().toString() + "_trendline" : internalVisitExpression(ctx.alias).toString();
String computationType = ctx.trendlineType().getText();
return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, Trendline.TrendlineType.valueOf(computationType.toUpperCase()));
}
Expand Down Expand Up @@ -537,7 +537,7 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct
public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubqueryClauseContext ctx) {
if (ctx.subSearch() != null) {
return ctx.alias != null
? new SubqueryAlias(ctx.alias.getText(), visitSubSearch(ctx.subSearch()))
? new SubqueryAlias(internalVisitExpression(ctx.alias).toString(), visitSubSearch(ctx.subSearch()))
: visitSubSearch(ctx.subSearch());
} else {
return visitTableSourceClause(ctx.tableSourceClause());
Expand All @@ -547,7 +547,7 @@ public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubq
@Override
public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) {
Relation relation = new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()));
return ctx.alias != null ? new SubqueryAlias(ctx.alias.getText(), relation) : relation;
return ctx.alias != null ? new SubqueryAlias(internalVisitExpression(ctx.alias).toString(), relation) : relation;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1105,4 +1105,14 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

test("test average price with backticks alias") {
val expectedPlan = planTransformer.visit(
plan(pplParser, "source = table | stats avg(price) as avg_price"),
new CatalystPlanContext)
val logPlan = planTransformer.visit(
plan(pplParser, "source = table | stats avg(`price`) as `avg_price`"),
new CatalystPlanContext)
comparePlans(expectedPlan, logPlan, false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,29 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
comparePlans(expectedPlan, logPlan, false)
}

test("Search multiple tables - with backticks table alias") {
val expectedPlan =
planTransformer.visit(
plan(
pplParser,
"""
| source=table1, table2, table3 as t
| | where t.name = 'Molly'
|""".stripMargin),
new CatalystPlanContext)
val logPlan =
planTransformer.visit(
plan(
pplParser,
"""
| source=table1, table2, table3 as `t`
| | where `t`.`name` = 'Molly'
|""".stripMargin),
new CatalystPlanContext)

comparePlans(expectedPlan, logPlan, false)
}

test("test fields + field list") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -845,4 +845,64 @@ class PPLLogicalPlanJoinTranslatorTestSuite
Project(Seq(UnresolvedAttribute("t1.name"), UnresolvedAttribute("t2.name")), joinPlan1)
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
}

test("test multiple joins with table and subquery backticks alias") {
val originPlan = plan(
pplParser,
s"""
| source = table1 as t1
| | JOIN left = l right = r ON t1.id = t2.id
| [
| source = table2 as t2
| ]
| | JOIN left = l right = r ON t2.id = t3.id
| [
| source = table3 as t3
| ]
| | JOIN left = l right = r ON t3.id = t4.id
| [
| source = table4 as t4
| ]
| """.stripMargin)
val expectedPlan = planTransformer.visit(originPlan, new CatalystPlanContext)
val logPlan = plan(
pplParser,
s"""
| source = table1 as `t1`
| | JOIN left = `l` right = `r` ON `t1`.`id` = `t2`.`id`
| [
| source = table2 as `t2`
| ]
| | JOIN left = `l` right = `r` ON `t2`.`id` = `t3`.`id`
| [
| source = table3 as `t3`
| ]
| | JOIN left = `l` right = `r` ON `t3`.`id` = `t4`.`id`
| [
| source = table4 as `t4`
| ]
| """.stripMargin)
val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext)
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
}

test("test complex backticks subquery alias") {
val originPlan = plan(
pplParser,
s"""
| source = $testTable1
| | JOIN left = t1 right = t2 ON t1.name = t2.name [ source = $testTable2 as ttt ] as tt
| | fields t1.name, t2.name
| """.stripMargin)
val expectedPlan = planTransformer.visit(originPlan, new CatalystPlanContext)
val logPlan = plan(
pplParser,
s"""
| source = $testTable1
| | JOIN left = `t1` right = `t2` ON `t1`.`name` = `t2`.`name` [ source = $testTable2 as `ttt` ] as `tt`
| | fields `t1`.`name`, `t2`.`name`
| """.stripMargin)
val logicalPlan = planTransformer.visit(logPlan, new CatalystPlanContext)
comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,16 @@ class PPLLogicalPlanRenameTranslatorTestSuite
Project(seq(UnresolvedAttribute("eval_rand")), planDropColumn)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

test("test rename with backticks alias") {
val expectedPlan =
planTransformer.visit(
plan(pplParser, "source=t | rename a as r_a, b as r_b | fields c"),
new CatalystPlanContext)
val logPlan =
planTransformer.visit(
plan(pplParser, "source=t | rename `a` as `r_a`, `b` as `r_b` | fields `c`"),
new CatalystPlanContext)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite
comparePlans(logPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline with sort and backticks alias") {
val expectedPlan =
planTransformer.visit(
plan(pplParser, "source=relation | trendline sort - age sma(3, age) as age_sma"),
new CatalystPlanContext)
val logPlan =
planTransformer.visit(
plan(pplParser, "source=relation | trendline sort - `age` sma(3, `age`) as `age_sma`"),
new CatalystPlanContext)
comparePlans(logPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline with multiple trendline sma commands") {
val context = new CatalystPlanContext
val logPlan =
Expand Down
Loading