[Do Not Merge][POC] Calcite integration #993

Closed · wants to merge 2 commits
3 changes: 3 additions & 0 deletions build.sbt
@@ -194,6 +194,9 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration"))
"com.github.sbt" % "junit-interface" % "0.13.3" % "test",
"org.projectlombok" % "lombok" % "1.18.30",
"com.github.seancfoley" % "ipaddress" % "5.5.1",
"org.apache.calcite" % "calcite-core" % "1.38.0",
"org.apache.calcite" % "calcite-linq4j" % "1.38.0",
"org.apache.calcite" % "calcite-testkit" % "1.38.0" % "test",
),
libraryDependencies ++= deps(sparkVersion),
// ANTLR settings
4 changes: 4 additions & 0 deletions Join.java
@@ -42,6 +42,10 @@ public List<UnresolvedPlan> getChild() {
return ImmutableList.of(left);
}

public List<UnresolvedPlan> getChildren() {
return ImmutableList.of(left, right);
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitJoin(this, context);
38 changes: 38 additions & 0 deletions CalciteAggCallVisitor.java
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite;

import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder.AggCall;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.calcite.utils.AggregateUtils;

public class CalciteAggCallVisitor extends AbstractNodeVisitor<AggCall, CalcitePlanContext> {
private final CalciteRexNodeVisitor rexNodeVisitor;

public CalciteAggCallVisitor(CalciteRexNodeVisitor rexNodeVisitor) {
this.rexNodeVisitor = rexNodeVisitor;
}

public AggCall analyze(UnresolvedExpression unresolved, CalcitePlanContext context) {
return unresolved.accept(this, context);
}

@Override
public AggCall visitAlias(Alias node, CalcitePlanContext context) {
AggCall aggCall = analyze(node.getDelegated(), context);
return aggCall.as(node.getName());
}

@Override
public AggCall visitAggregateFunction(AggregateFunction node, CalcitePlanContext context) {
RexNode field = rexNodeVisitor.analyze(node.getField(), context);
return AggregateUtils.translate(node, field, context);
}
}
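A hedged sketch of how an aliased aggregation flows through this visitor (node shapes assumed, mirroring the tests further down): for `stats avg(SAL) as avg_sal`, visitAlias analyzes the delegated AggregateFunction and then names the result.

// Sketch only: aliasNode is an assumed Alias("avg_sal", avg(SAL)) input.
AggCall call = aggVisitor.analyze(aliasNode, context);
// visitAggregateFunction resolves SAL via rexNodeVisitor, then
// AggregateUtils.translate maps "avg" to context.relBuilder.avg(...);
// visitAlias finally applies .as("avg_sal").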
40 changes: 40 additions & 0 deletions CalcitePlanContext.java
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite;

import lombok.Getter;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.RelBuilder;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.function.BiFunction;

public class CalcitePlanContext {

public final RelBuilder relBuilder;
public final ExtendedRexBuilder rexBuilder;

@Getter private boolean isResolvingJoinCondition = false;

public CalcitePlanContext(RelBuilder relBuilder) {
this.relBuilder = relBuilder;
this.rexBuilder = new ExtendedRexBuilder(relBuilder.getRexBuilder());
}

public RexNode resolveJoinCondition(
UnresolvedExpression expr,
BiFunction<UnresolvedExpression, CalcitePlanContext, RexNode> transformFunction) {
isResolvingJoinCondition = true;
try {
return transformFunction.apply(expr, this);
} finally {
// Reset the flag even if resolution throws, so later expressions are
// not resolved with join-condition semantics.
isResolvingJoinCondition = false;
}
}

public static CalcitePlanContext create(FrameworkConfig config) {
return new CalcitePlanContext(RelBuilder.create(config));
}
}
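A minimal usage sketch of the join-condition guard (the expression and visitor instances are assumed): while the flag is set, CalciteRexNodeVisitor#visitQualifiedName resolves single-part names against both join inputs rather than only the top of the RelBuilder stack.

// Sketch only: expr and rexVisitor are assumed to exist in the caller.
RexNode condition = context.resolveJoinCondition(expr, rexVisitor::analyze);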
261 changes: 261 additions & 0 deletions CalciteRelNodeVisitor.java
@@ -0,0 +1,261 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite;

import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder.AggCall;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.AllFields;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Join;
import org.opensearch.sql.ast.tree.Lookup;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.apache.calcite.sql.SqlKind.AS;
import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_FIRST;
import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_LAST;
import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;

public class CalciteRelNodeVisitor extends AbstractNodeVisitor<Void, CalcitePlanContext> {

private final CalciteRexNodeVisitor rexVisitor;
private final CalciteAggCallVisitor aggVisitor;

public CalciteRelNodeVisitor() {
this.rexVisitor = new CalciteRexNodeVisitor();
this.aggVisitor = new CalciteAggCallVisitor(rexVisitor);
}

public Void analyze(UnresolvedPlan unresolved, CalcitePlanContext context) {
return unresolved.accept(this, context);
}

@Override
public Void visitRelation(Relation node, CalcitePlanContext context) {
for (QualifiedName qualifiedName : node.getQualifiedNames()) {
context.relBuilder.scan(qualifiedName.getParts());
}
if (node.getQualifiedNames().size() > 1) {
context.relBuilder.union(true, node.getQualifiedNames().size());
}
return null;
}

@Override
public Void visitFilter(Filter node, CalcitePlanContext context) {
visitChildren(node, context);
RexNode condition = rexVisitor.analyze(node.getCondition(), context);
context.relBuilder.filter(condition);
return null;
}

@Override
public Void visitProject(Project node, CalcitePlanContext context) {
visitChildren(node, context);
List<RexNode> projectList = node.getProjectList().stream()
.filter(expr -> !(expr instanceof AllFields))
.map(expr -> rexVisitor.analyze(expr, context))
.collect(Collectors.toList());
if (projectList.isEmpty()) {
return null;
}
if (node.isExcluded()) {
context.relBuilder.projectExcept(projectList);
} else {
context.relBuilder.project(projectList);
}
return null;
}

@Override
public Void visitSort(Sort node, CalcitePlanContext context) {
visitChildren(node, context);
List<RexNode> sortList = node.getSortList().stream().map(
expr -> {
RexNode sortField = rexVisitor.analyze(expr, context);
Sort.SortOption sortOption = analyzeSortOption(expr.getFieldArgs());
if (sortOption == DEFAULT_DESC) {
return context.relBuilder.desc(sortField);
} else {
return sortField;
}
}).collect(Collectors.toList());
context.relBuilder.sort(sortList);
return null;
}

private Sort.SortOption analyzeSortOption(List<Argument> fieldArgs) {
Boolean asc = (Boolean) fieldArgs.get(0).getValue().getValue();
Optional<Argument> nullFirst =
fieldArgs.stream().filter(option -> "nullFirst".equals(option.getName())).findFirst();

if (nullFirst.isPresent()) {
Boolean isNullFirst = (Boolean) nullFirst.get().getValue().getValue();
return new Sort.SortOption((asc ? ASC : DESC), (isNullFirst ? NULL_FIRST : NULL_LAST));
}
return asc ? Sort.SortOption.DEFAULT_ASC : DEFAULT_DESC;
}

@Override
public Void visitHead(Head node, CalcitePlanContext context) {
visitChildren(node, context);
context.relBuilder.limit(node.getFrom(), node.getSize());
return null;
}

@Override
public Void visitEval(Eval node, CalcitePlanContext context) {
visitChildren(node, context);
List<String> originalFieldNames = context.relBuilder.peek().getRowType().getFieldNames();
List<RexNode> evalList = node.getExpressionList().stream()
.map(expr -> {
RexNode eval = rexVisitor.analyze(expr, context);
context.relBuilder.projectPlus(eval);
return eval;
}).collect(Collectors.toList());
// Override an existing field when an eval alias has the same name as an original field, e.g. `eval SAL = SAL + 1`
List<String> overriding = evalList.stream().filter(expr -> expr.getKind() == AS)
.map(expr -> ((RexLiteral) ((RexCall) expr).getOperands().get(1)).getValueAs(String.class))
.collect(Collectors.toList());
overriding.retainAll(originalFieldNames);
if (!overriding.isEmpty()) {
List<RexNode> toDrop = context.relBuilder.fields(overriding);
context.relBuilder.projectExcept(toDrop);
}
return null;
}

@Override
public Void visitAggregation(Aggregation node, CalcitePlanContext context) {
visitChildren(node, context);
List<AggCall> aggList = node.getAggExprList().stream()
.map(expr -> aggVisitor.analyze(expr, context))
.collect(Collectors.toList());
List<RexNode> groupByList = node.getGroupExprList().stream()
.map(expr -> rexVisitor.analyze(expr, context))
.collect(Collectors.toList());

UnresolvedExpression span = node.getSpan();
if (!Objects.isNull(span)) {
RexNode spanRex = rexVisitor.analyze(span, context);
// Add the span bucket expression as an extra group-by key; the most
// recently added expression becomes the span's alias field.
groupByList.add(spanRex);
}
context.relBuilder.aggregate(context.relBuilder.groupKey(groupByList), aggList);
return null;
}

@Override
public Void visitJoin(Join node, CalcitePlanContext context) {
List<UnresolvedPlan> children = node.getChildren();
children.forEach(c -> analyze(c, context));
RexNode joinCondition = node.getJoinCondition().map(c -> rexVisitor.analyzeJoinCondition(c, context))
.orElse(context.relBuilder.literal(true));
context.relBuilder.join(JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
return null;
}

@Override
public Void visitSubqueryAlias(SubqueryAlias node, CalcitePlanContext context) {
visitChildren(node, context);
context.relBuilder.as(node.getAlias());
return null;
}

@Override
public Void visitLookup(Lookup node, CalcitePlanContext context) {
// 1. resolve source side
visitChildren(node, context);
// get sourceOutputFields from top of stack which is used to build final output
List<RexNode> sourceOutputFields = context.relBuilder.fields();

// 2. resolve lookup table
analyze(node.getLookupRelation(), context);
// If the output fields are specified, build a project list for lookup table.
// The mapping fields of lookup table should be added in this project list, otherwise join will fail.
// So the mapping fields of lookup table should be dropped after join.
List<RexNode> projectList = JoinAndLookupUtils.buildLookupRelationProjectList(node, rexVisitor, context);
if (!projectList.isEmpty()) {
context.relBuilder.project(projectList);
}

// 3. resolve join condition
RexNode joinCondition = JoinAndLookupUtils.buildLookupMappingCondition(node)
.map(c -> rexVisitor.analyzeJoinCondition(c, context))
.orElse(context.relBuilder.literal(true));

// 4. If no output field is specified, all fields from lookup table are applied to the output.
if (node.allFieldsShouldAppliedToOutputList()) {
context.relBuilder.join(JoinRelType.LEFT, joinCondition);
return null;
}

// 5. push join to stack
context.relBuilder.join(JoinRelType.LEFT, joinCondition);

// 6. Drop the mapping fields of lookup table in result:
// For example, in command "LOOKUP lookTbl Field1 AS Field2, Field3",
// the Field1 and Field3 are projection fields and join keys which will be dropped in result.
List<Field> mappingFieldsOfLookup = node.getLookupMappingMap().entrySet().stream()
.map(kv -> kv.getKey().getField() == kv.getValue().getField() ? JoinAndLookupUtils.buildFieldWithLookupSubqueryAlias(node, kv.getKey()) : kv.getKey())
.collect(Collectors.toList());
List<RexNode> dropListOfLookupMappingFields =
JoinAndLookupUtils.buildProjectListFromFields(mappingFieldsOfLookup, rexVisitor, context);
// Drop the $sourceOutputField if existing
List<RexNode> dropListOfSourceFields =
node.getFieldListWithSourceSubqueryAlias().stream().map( field -> {
try {
return rexVisitor.analyze(field, context);
} catch (RuntimeException e) {
// If the field is not found in the source, skip it
return null;
}
}).filter(Objects::nonNull).collect(Collectors.toList());
List<RexNode> toDrop = new ArrayList<>(dropListOfLookupMappingFields);
toDrop.addAll(dropListOfSourceFields);

// 7. build final outputs
List<RexNode> outputFields = new ArrayList<>(sourceOutputFields);
// Add new columns based on different strategies:
// Append: coalesce($outputField, $"inputField").as(outputFieldName)
// Replace: $outputField.as(outputFieldName)
outputFields.addAll(JoinAndLookupUtils.buildOutputProjectList(node, rexVisitor, context));
outputFields.removeAll(toDrop);

context.relBuilder.project(outputFields);

return null;
}
}
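Taken together, this plan visitor is driven as in CalcitePPLAbstractTest#getRelNode further down; a condensed sketch:

// Condensed from the test harness below; `config` is a Frameworks config.
CalcitePlanContext context = CalcitePlanContext.create(config);
new CalciteRelNodeVisitor().analyze(query.getPlan(), context); // builds onto the RelBuilder stack
RelNode root = context.relBuilder.build();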
237 changes: 237 additions & 0 deletions CalciteRexNodeVisitor.java
@@ -0,0 +1,237 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.DateString;
import org.apache.calcite.util.TimeString;
import org.apache.calcite.util.TimestampString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.And;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.calcite.utils.BuiltinFunctionUtils;
import org.opensearch.sql.ppl.utils.DataTypeTransformer;

import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN;

public class CalciteRexNodeVisitor extends AbstractNodeVisitor<RexNode, CalcitePlanContext> {

public RexNode analyze(UnresolvedExpression unresolved, CalcitePlanContext context) {
return unresolved.accept(this, context);
}

public RexNode analyzeJoinCondition(UnresolvedExpression unresolved, CalcitePlanContext context) {
return context.resolveJoinCondition(unresolved, this::analyze);
}

@Override
public RexNode visitLiteral(Literal node, CalcitePlanContext context) {
RexBuilder rexBuilder = context.rexBuilder;
RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
final Object value = node.getValue();
if (value == null) {
final RelDataType type = typeFactory.createSqlType(SqlTypeName.NULL);
return rexBuilder.makeNullLiteral(type);
}
switch (node.getType()) {
case NULL:
return rexBuilder.makeNullLiteral(typeFactory.createSqlType(SqlTypeName.NULL));
case STRING:
return rexBuilder.makeLiteral(value.toString());
case INTEGER:
return rexBuilder.makeExactLiteral(new BigDecimal((Integer) value));
case LONG:
return rexBuilder.makeBigintLiteral(new BigDecimal((Long) value));
case SHORT:
return rexBuilder.makeExactLiteral(new BigDecimal((Short) value), typeFactory.createSqlType(SqlTypeName.SMALLINT));
case FLOAT:
return rexBuilder.makeApproxLiteral(new BigDecimal(Float.toString((Float) value)), typeFactory.createSqlType(SqlTypeName.FLOAT));
case DOUBLE:
return rexBuilder.makeApproxLiteral(new BigDecimal(Double.toString((Double) value)), typeFactory.createSqlType(SqlTypeName.DOUBLE));
case BOOLEAN:
return rexBuilder.makeLiteral((Boolean) value);
case DATE:
return rexBuilder.makeDateLiteral(new DateString(value.toString()));
case TIME:
return rexBuilder.makeTimeLiteral(new TimeString(value.toString()), RelDataType.PRECISION_NOT_SPECIFIED);
case TIMESTAMP:
return rexBuilder.makeTimestampLiteral(new TimestampString(value.toString()), RelDataType.PRECISION_NOT_SPECIFIED);
case INTERVAL:
// return rexBuilder.makeIntervalLiteral(BigDecimal.valueOf((long) node.getValue()));
default:
throw new UnsupportedOperationException("Unsupported literal type: " + node.getType());
}
}

@Override
public RexNode visitAnd(And node, CalcitePlanContext context) {
final RelDataType booleanType = context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN);
final RexNode left = analyze(node.getLeft(), context);
final RexNode right = analyze(node.getRight(), context);
return context.rexBuilder.makeCall(booleanType, SqlStdOperatorTable.AND, List.of(left, right));
}

@Override
public RexNode visitOr(Or node, CalcitePlanContext context) {
final RexNode left = analyze(node.getLeft(), context);
final RexNode right = analyze(node.getRight(), context);
return context.relBuilder.or(left, right);
}

@Override
public RexNode visitXor(Xor node, CalcitePlanContext context) {
final RelDataType booleanType = context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN);
final RexNode left = analyze(node.getLeft(), context);
final RexNode right = analyze(node.getRight(), context);
return context.rexBuilder.makeCall(booleanType, SqlStdOperatorTable.BIT_XOR, List.of(left, right));
}

@Override
public RexNode visitNot(Not node, CalcitePlanContext context) {
final RexNode expr = analyze(node.getExpression(), context);
return context.relBuilder.not(expr);
}

@Override
public RexNode visitCompare(Compare node, CalcitePlanContext context) {
final RelDataType booleanType = context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN);
final RexNode left = analyze(node.getLeft(), context);
final RexNode right = analyze(node.getRight(), context);
return context.rexBuilder.makeCall(booleanType, BuiltinFunctionUtils.translate(node.getOperator()), List.of(left, right));
}

@Override
public RexNode visitEqualTo(EqualTo node, CalcitePlanContext context) {
final RexNode left = analyze(node.getLeft(), context);
final RexNode right = analyze(node.getRight(), context);
return context.rexBuilder.equals(left, right);
}

@Override
public RexNode visitQualifiedName(QualifiedName node, CalcitePlanContext context) {
if (context.isResolvingJoinCondition()) {
List<String> parts = node.getParts();
if (parts.size() == 1) { // Handle the case of `id = cid`
try {
return context.relBuilder.field(2, 0, parts.get(0));
} catch (IllegalArgumentException i) {
return context.relBuilder.field(2, 1, parts.get(0));
}
} else if (parts.size() == 2) { // Handle the case of `t1.id = t2.id` or `alias1.id = alias2.id`
return context.relBuilder.field(2, parts.get(0), parts.get(1));
} else if (parts.size() == 3) {
throw new UnsupportedOperationException("Unsupported qualified name: " + node);
}
}
String qualifiedName = node.toString();
List<String> currentFields = context.relBuilder.peek().getRowType().getFieldNames();
if (currentFields.contains(qualifiedName)) {
return context.relBuilder.field(qualifiedName);
} else if (node.getParts().size() == 2) {
List<String> parts = node.getParts();
return context.relBuilder.field(1, parts.get(0), parts.get(1));
} else if (currentFields.stream().noneMatch(f -> f.startsWith(qualifiedName))) {
return context.relBuilder.field(qualifiedName);
}
// Handle the overriding fields, for example, `eval SAL = SAL + 1` will delete the original SAL and add a SAL0
Map<String, String> fieldMap = currentFields.stream()
.collect(Collectors.toMap(s -> s.replaceAll("\\d", ""), s -> s));
if (fieldMap.containsKey(qualifiedName)) {
return context.relBuilder.field(fieldMap.get(qualifiedName));
} else {
return null;
}
}

@Override
public RexNode visitAlias(Alias node, CalcitePlanContext context) {
RexNode expr = analyze(node.getDelegated(), context);
return context.relBuilder.alias(expr, node.getName());
}

@Override
public RexNode visitSpan(Span node, CalcitePlanContext context) {
RexNode field = analyze(node.getField(), context);
RexNode value = analyze(node.getValue(), context);
RelDataTypeFactory typeFactory = context.rexBuilder.getTypeFactory();
SpanUnit unit = node.getUnit();
if (isTimeBased(unit)) {
String datetimeUnitString = DataTypeTransformer.translate(unit);
RexNode interval = context.rexBuilder.makeIntervalLiteral(
new BigDecimal(value.toString()),
new SqlIntervalQualifier(datetimeUnitString, SqlParserPos.ZERO));
// TODO not supported yet
return interval;
} else {
// if the unit is not time-based, create a math expression to bucket the span partitions
return context.rexBuilder.makeCall(
typeFactory.createSqlType(SqlTypeName.DOUBLE),
SqlStdOperatorTable.MULTIPLY,
List.of(
context.rexBuilder.makeCall(typeFactory.createSqlType(SqlTypeName.DOUBLE),
SqlStdOperatorTable.FLOOR,
List.of(
context.rexBuilder.makeCall(typeFactory.createSqlType(SqlTypeName.DOUBLE),
SqlStdOperatorTable.DIVIDE,
List.of(field, value)
)
)
),
value));
}

}

private boolean isTimeBased(SpanUnit unit) {
return !(unit == NONE || unit == UNKNOWN);
}

@Override
public RexNode visitLet(Let node, CalcitePlanContext context) {
RexNode expr = analyze(node.getExpression(), context);
return context.relBuilder.alias(expr, node.getVar().getField().toString());
}

@Override
public RexNode visitFunction(Function node, CalcitePlanContext context) {
List<RexNode> arguments = node.getFuncArgs().stream()
.map(arg -> analyze(arg, context))
.collect(Collectors.toList());
return context.rexBuilder.makeCall(BuiltinFunctionUtils.translate(node.getFuncName()), arguments);
}
}
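For non-time units the span translation is plain bucketing arithmetic, FLOOR(field / value) * value; for example, with `span(EMPNO, 100)`, EMPNO = 7369 falls into bucket FLOOR(7369 / 100) * 100 = 7300, matching the expected results in CalcitePPLAggregationTest#testAvgBySpan.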
25 changes: 25 additions & 0 deletions ExtendedRexBuilder.java
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite;

import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;

public class ExtendedRexBuilder extends RexBuilder {

public ExtendedRexBuilder(RexBuilder rexBuilder) {
super(rexBuilder.getTypeFactory());
}

public RexNode coalesce(RexNode... nodes) {
return this.makeCall(SqlStdOperatorTable.COALESCE, nodes);
}

public RexNode equals(RexNode n1, RexNode n2) {
return this.makeCall(SqlStdOperatorTable.EQUALS, n1, n2);
}
}
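Both helpers are consumed by the lookup translation; a hedged sketch with assumed operand names:

// Sketch only: sourceKey/lookupKey/outputCol/inputCol are assumed RexNodes.
RexNode joinKeyEq = rexBuilder.equals(sourceKey, lookupKey);  // source.key = lookup.key
RexNode appended = rexBuilder.coalesce(outputCol, inputCol);  // APPEND strategy fallback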
42 changes: 42 additions & 0 deletions TimeWindow.java
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite;

import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rex.RexNode;

import java.util.List;

public class TimeWindow extends SingleRel {
private final RexNode timeColumn;
private final RexNode windowDuration;
private final RexNode slideDuration;
private final RexNode startTime;

public TimeWindow(
RelOptCluster cluster,
RelTraitSet traits,
RelNode input,
RexNode timeColumn,
RexNode windowDuration,
RexNode slideDuration,
RexNode startTime) {
super(cluster, traits, input);
this.timeColumn = timeColumn;
this.windowDuration = windowDuration;
this.slideDuration = slideDuration;
this.startTime = startTime;
}

@Override
public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
return new TimeWindow(getCluster(), traitSet, sole(inputs),
timeColumn, windowDuration, slideDuration, startTime);
}
}
85 changes: 85 additions & 0 deletions AggregateUtils.java
@@ -0,0 +1,85 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.utils;

import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

public interface AggregateUtils {

static RelBuilder.AggCall translate(AggregateFunction agg, RexNode field, CalcitePlanContext context) {
if (BuiltinFunctionName.ofAggregation(agg.getFuncName()).isEmpty())
throw new IllegalStateException("Unexpected value: " + agg.getFuncName());

// Additional aggregation function operators will be added here
BuiltinFunctionName functionName = BuiltinFunctionName.ofAggregation(agg.getFuncName()).get();
switch (functionName) {
case MAX:
return context.relBuilder.max(field);
case MIN:
return context.relBuilder.min(field);
case MEAN:
throw new UnsupportedOperationException("MEAN is not supported in PPL");
case AVG:
return context.relBuilder.avg(agg.getDistinct(), null, field);
case COUNT:
return context.relBuilder.count(agg.getDistinct(), null, field == null ? ImmutableList.of() : ImmutableList.of(field));
case SUM:
return context.relBuilder.sum(agg.getDistinct(), null, field);
case STDDEV:
return context.relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV, field);
case STDDEV_POP:
return context.relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV_POP, field);
case STDDEV_SAMP:
return context.relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV_SAMP, field);
case PERCENTILE:
return context.relBuilder.aggregateCall(SqlStdOperatorTable.PERCENTILE_CONT, field);
case PERCENTILE_APPROX:
throw new UnsupportedOperationException("PERCENTILE_APPROX is not supported in PPL");
case APPROX_COUNT_DISTINCT:
return context.relBuilder.aggregateCall(SqlStdOperatorTable.APPROX_COUNT_DISTINCT, field);
}
throw new IllegalStateException("Not Supported value: " + agg.getFuncName());
}

static AggregateCall translateAggregateCall(AggregateFunction agg, RexNode field, RelBuilder relBuilder) {
if (BuiltinFunctionName.ofAggregation(agg.getFuncName()).isEmpty())
throw new IllegalStateException("Unexpected value: " + agg.getFuncName());

// Additional aggregation function operators will be added here
BuiltinFunctionName functionName = BuiltinFunctionName.ofAggregation(agg.getFuncName()).get();
boolean isDistinct = agg.getDistinct();
switch (functionName) {
case MAX:
return aggCreate(SqlStdOperatorTable.MAX, isDistinct, field);
case MIN:
return aggCreate(SqlStdOperatorTable.MIN, isDistinct, field);
case MEAN:
throw new UnsupportedOperationException("MEAN is not supported in PPL");
case AVG:
return aggCreate(SqlStdOperatorTable.AVG, isDistinct, field);
case COUNT:
return aggCreate(SqlStdOperatorTable.COUNT, isDistinct, field);
case SUM:
return aggCreate(SqlStdOperatorTable.SUM, isDistinct, field);
}
throw new IllegalStateException("Not Supported value: " + agg.getFuncName());
}

static AggregateCall aggCreate(SqlAggFunction agg, boolean isDistinct, RexNode field) {
int index = ((RexInputRef) field).getIndex();
return AggregateCall.create(agg, isDistinct, false, false, ImmutableList.of(), ImmutableList.of(index), -1, null, RelCollations.EMPTY, field.getType(), null);
}
}
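As an example of the mapping (consistent with the tests below), `stats count() as c` reaches translate with a null field and becomes context.relBuilder.count(false, null, ImmutableList.of()), i.e. COUNT(*), while `avg(SAL)` resolves SAL to an input ref and yields relBuilder.avg(false, null, field).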
71 changes: 71 additions & 0 deletions BuiltinFunctionUtils.java
@@ -0,0 +1,71 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.utils;

import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlLibraryOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;

import java.util.Locale;

public interface BuiltinFunctionUtils {

static SqlOperator translate(String op) {
switch (op.toUpperCase(Locale.ROOT)) {
case "AND":
return SqlStdOperatorTable.AND;
case "OR":
return SqlStdOperatorTable.OR;
case "NOT":
return SqlStdOperatorTable.NOT;
case "XOR":
return SqlStdOperatorTable.BIT_XOR;
case "=":
return SqlStdOperatorTable.EQUALS;
case "<>":
case "!=":
return SqlStdOperatorTable.NOT_EQUALS;
case ">":
return SqlStdOperatorTable.GREATER_THAN;
case ">=":
return SqlStdOperatorTable.GREATER_THAN_OR_EQUAL;
case "<":
return SqlStdOperatorTable.LESS_THAN;
case "<=":
return SqlStdOperatorTable.LESS_THAN_OR_EQUAL;
case "+":
return SqlStdOperatorTable.PLUS;
case "-":
return SqlStdOperatorTable.MINUS;
case "*":
return SqlStdOperatorTable.MULTIPLY;
case "/":
return SqlStdOperatorTable.DIVIDE;
// Built-in String Functions
case "LOWER":
return SqlStdOperatorTable.LOWER;
case "LIKE":
return SqlStdOperatorTable.LIKE;
// Built-in Math Functions
case "ABS":
return SqlStdOperatorTable.ABS;
// Built-in Date Functions
case "CURRENT_TIMESTAMP":
return SqlStdOperatorTable.CURRENT_TIMESTAMP;
case "CURRENT_DATE":
return SqlStdOperatorTable.CURRENT_DATE;
case "DATE":
return SqlLibraryOperators.DATE;
case "ADDDATE":
return SqlLibraryOperators.DATE_ADD_SPARK;
case "DATE_ADD":
return SqlLibraryOperators.DATEADD;
// TODO Add more, ref RexImpTable
default:
throw new IllegalArgumentException("Unsupported operator: " + op);
}
}
}
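visitCompare in CalciteRexNodeVisitor pairs this table with the operands it resolves; a minimal sketch (operands assumed):

// Sketch only: left/right are assumed RexNodes, booleanType a BOOLEAN RelDataType.
SqlOperator op = BuiltinFunctionUtils.translate(">"); // -> SqlStdOperatorTable.GREATER_THAN
RexNode cmp = context.rexBuilder.makeCall(booleanType, op, List.of(left, right));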
124 changes: 124 additions & 0 deletions JoinAndLookupUtils.java
@@ -0,0 +1,124 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.utils;

import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexNode;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.And;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.Join;
import org.opensearch.sql.ast.tree.Lookup;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.calcite.CalciteRexNodeVisitor;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

public interface JoinAndLookupUtils {

static JoinRelType translateJoinType(Join.JoinType joinType) {
switch (joinType) {
case LEFT:
return JoinRelType.LEFT;
case RIGHT:
return JoinRelType.RIGHT;
case FULL:
return JoinRelType.FULL;
case SEMI:
return JoinRelType.SEMI;
case ANTI:
return JoinRelType.ANTI;
case INNER:
default:
return JoinRelType.INNER;
}
}

static Optional<UnresolvedExpression> buildLookupMappingCondition(Lookup node) {
// only equi-join conditions are accepted in lookup command
List<UnresolvedExpression> equiConditions = new ArrayList<>();
for (Map.Entry<Field, Field> entry : node.getLookupMappingMap().entrySet()) {
EqualTo equalTo;
if (entry.getKey().getField() == entry.getValue().getField()) {
Field lookupWithAlias = buildFieldWithLookupSubqueryAlias(node, entry.getKey());
Field sourceWithAlias = buildFieldWithSourceSubqueryAlias(node, entry.getValue());
equalTo = new EqualTo(sourceWithAlias, lookupWithAlias);
} else {
equalTo = new EqualTo(entry.getValue(), entry.getKey());
}

equiConditions.add(equalTo);
}
return equiConditions.stream().reduce(And::new);
}

static Field buildFieldWithLookupSubqueryAlias(Lookup node, Field field) {
return new Field(QualifiedName.of(node.getLookupSubqueryAliasName(), field.getField().toString()));
}

static Field buildFieldWithSourceSubqueryAlias(Lookup node, Field field) {
return new Field(QualifiedName.of(node.getSourceSubqueryAliasName(), field.getField().toString()));
}

/** Builds the lookup relation project list: lookup mapping fields + input fields. */
static List<RexNode> buildLookupRelationProjectList(
Lookup node,
CalciteRexNodeVisitor rexVisitor,
CalcitePlanContext context) {
List<Field> lookupMappingFields = new ArrayList<>(node.getLookupMappingMap().keySet());
List<Field> inputFields = new ArrayList<>(node.getInputFieldList());
if (inputFields.isEmpty()) {
// All fields will be applied to the output if no input field is specified.
return Collections.emptyList();
}
lookupMappingFields.addAll(inputFields);
return buildProjectListFromFields(lookupMappingFields, rexVisitor, context);
}

static List<RexNode> buildProjectListFromFields(
List<Field> fields,
CalciteRexNodeVisitor rexVisitor,
CalcitePlanContext context) {
return fields.stream()
.map(expr -> rexVisitor.analyze(expr, context))
.collect(Collectors.toList());
}

static List<RexNode> buildOutputProjectList(
Lookup node,
CalciteRexNodeVisitor rexVisitor,
CalcitePlanContext context) {
List<RexNode> outputProjectList = new ArrayList<>();
for (Map.Entry<Alias, Field> entry : node.getOutputCandidateMap().entrySet()) {
Alias inputFieldWithAlias = entry.getKey();
Field inputField = (Field) inputFieldWithAlias.getDelegated();
Field outputField = entry.getValue();
RexNode inputCol = rexVisitor.visitField(inputField, context);
RexNode outputCol = rexVisitor.visitField(outputField, context);

RexNode child;
if (node.getOutputStrategy() == Lookup.OutputStrategy.APPEND) {
child = context.rexBuilder.coalesce(outputCol, inputCol);
} else {
child = inputCol;
}
// The output project list built here replaces the source output.
// For rows unmatched by the left outer join, the lookup outputs are null,
// so fall back to the source output.
RexNode nullSafeOutput = context.rexBuilder.coalesce(child, outputCol);
RexNode withAlias = context.relBuilder.alias(nullSafeOutput, inputFieldWithAlias.getName());
outputProjectList.add(withAlias);
}
return outputProjectList;
}
}
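For example, for `LOOKUP lookTbl Field1 AS Field2`, buildLookupMappingCondition produces one EqualTo per mapping entry and reduces them with And; when a mapping uses the same name on both sides, the fields are first qualified with the generated source and lookup subquery aliases so the equi-join keys stay unambiguous.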
124 changes: 124 additions & 0 deletions CalcitePPLAbstractTest.java
@@ -0,0 +1,124 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl;

import lombok.Getter;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelTraitDef;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.rel2sql.RelToSqlConverter;
import org.apache.calcite.rel.rel2sql.SqlImplementor;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.dialect.SparkSqlDialect;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.test.CalciteAssert;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.tools.Programs;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelRunners;
import org.opensearch.flint.spark.ppl.PPLSyntaxParser;
import org.opensearch.sql.ast.statement.Query;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.calcite.CalciteRelNodeVisitor;

import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.List;
import java.util.function.UnaryOperator;

import static org.apache.calcite.test.Matchers.hasTree;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.opensearch.flint.spark.ppl.PlaneUtils.plan;

public class CalcitePPLAbstractTest {
@Getter private final Frameworks.ConfigBuilder config;
@Getter private final CalcitePlanContext context;
private final CalciteRelNodeVisitor planTransformer;
private final RelToSqlConverter converter;

public CalcitePPLAbstractTest(CalciteAssert.SchemaSpec... schemaSpecs) {
this.config = config(schemaSpecs);
this.context = createBuilderContext();
this.planTransformer = new CalciteRelNodeVisitor();
this.converter = new RelToSqlConverter(SparkSqlDialect.DEFAULT);
}

public PPLSyntaxParser pplParser = new PPLSyntaxParser();

protected Frameworks.ConfigBuilder config(CalciteAssert.SchemaSpec... schemaSpecs) {
final SchemaPlus rootSchema = Frameworks.createRootSchema(true);
final SchemaPlus schema = CalciteAssert.addSchema(rootSchema, schemaSpecs);
return Frameworks.newConfigBuilder()
.parserConfig(SqlParser.Config.DEFAULT)
.defaultSchema(schema)
.traitDefs((List<RelTraitDef>) null)
.programs(Programs.heuristicJoinOrder(Programs.RULE_SET, true, 2));
}

/** Creates a RelBuilder with default config. */
protected CalcitePlanContext createBuilderContext() {
return createBuilderContext(c -> c);
}

/** Creates a CalcitePlanContext with transformed config. */
private CalcitePlanContext createBuilderContext(UnaryOperator<RelBuilder.Config> transform) {
config.context(Contexts.of(transform.apply(RelBuilder.Config.DEFAULT)));
return CalcitePlanContext.create(config.build());
}

/**
* Get the root RelNode of the given PPL query
*/
public RelNode getRelNode(String ppl) {
Query query = (Query) plan(pplParser, ppl);
planTransformer.analyze(query.getPlan(), context);
RelNode root = context.relBuilder.build();
System.out.println(root.explain());
return root;
}

/**
* Verify the logical plan of the given RelNode
*/
public void verifyLogical(RelNode rel, String expectedLogical) {
assertThat(rel, hasTree(expectedLogical));
}

/**
* Execute and verify the result of the given RelNode
*/
public void verifyResult(RelNode rel, String expectedResult) {
try (PreparedStatement preparedStatement = RelRunners.run(rel)) {
String s = CalciteAssert.toString(preparedStatement.executeQuery());
assertThat(s, is(expectedResult));
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

/**
* Execute and verify the result count of the given RelNode
*/
public void verifyResultCount(RelNode rel, int expectedRows) {
try (PreparedStatement preparedStatement = RelRunners.run(rel)) {
CalciteAssert.checkResultCount(is(expectedRows)).accept(preparedStatement.executeQuery());
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

/**
* Verify the generated Spark SQL of the given RelNode
*/
public void verifyPPLToSparkSQL(RelNode rel, String expected) {
SqlImplementor.Result result = converter.visitRoot(rel);
final SqlNode sqlNode = result.asStatement();
final String sql = sqlNode.toSqlString(SparkSqlDialect.DEFAULT).getSql();
assertThat(sql, is(expected));
}
}
180 changes: 180 additions & 0 deletions CalcitePPLAggregationTest.java
@@ -0,0 +1,180 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.test.CalciteAssert;
import org.junit.Ignore;
import org.junit.Test;

public class CalcitePPLAggregationTest extends CalcitePPLAbstractTest {

public CalcitePPLAggregationTest() {
super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL);
}

@Test
public void testSimpleCount() {
String ppl = "source=EMP | stats count() as c";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalAggregate(group=[{}], c=[COUNT()])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = "c=14\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT COUNT(*) `c`\n"
+ "FROM `scott`.`EMP`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testSimpleAvg() {
String ppl = "source=EMP | stats avg(SAL)";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalAggregate(group=[{}], avg(SAL)=[AVG($5)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = "avg(SAL)=2073.21\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT AVG(`SAL`) `avg(SAL)`\n"
+ "FROM `scott`.`EMP`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testMultipleAggregatesWithAliases() {
String ppl = "source=EMP | stats avg(SAL) as avg_sal, max(SAL) as max_sal, min(SAL) as min_sal, count() as cnt";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalAggregate(group=[{}], avg_sal=[AVG($5)], max_sal=[MAX($5)], min_sal=[MIN($5)], cnt=[COUNT()])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = "avg_sal=2073.21; max_sal=5000.00; min_sal=800.00; cnt=14\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT AVG(`SAL`) `avg_sal`, MAX(`SAL`) `max_sal`, MIN(`SAL`) `min_sal`, COUNT(*) `cnt`\n"
+ "FROM `scott`.`EMP`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testAvgByField() {
String ppl = "source=EMP | stats avg(SAL) by DEPTNO";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalAggregate(group=[{7}], avg(SAL)=[AVG($5)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "DEPTNO=20; avg(SAL)=2175.00\n"
+ "DEPTNO=10; avg(SAL)=2916.66\n"
+ "DEPTNO=30; avg(SAL)=1566.66\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `DEPTNO`, AVG(`SAL`) `avg(SAL)`\n"
+ "FROM `scott`.`EMP`\n"
+ "GROUP BY `DEPTNO`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testAvgBySpan() {
String ppl = "source=EMP | stats avg(SAL) by span(EMPNO, 100) as empno_span";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalAggregate(group=[{1}], avg(SAL)=[AVG($0)])\n"
+ " LogicalProject(SAL=[$5], empno_span=[*(FLOOR(/($0, 100)), 100)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "empno_span=7300.0; avg(SAL)=800.00\n"
+ "empno_span=7400.0; avg(SAL)=1600.00\n"
+ "empno_span=7500.0; avg(SAL)=2112.50\n"
+ "empno_span=7600.0; avg(SAL)=2050.00\n"
+ "empno_span=7700.0; avg(SAL)=2725.00\n"
+ "empno_span=7800.0; avg(SAL)=2533.33\n"
+ "empno_span=7900.0; avg(SAL)=1750.00\n";
verifyResult(root, expectedResult);
}

@Test
public void testAvgBySpanAndFields() {
String ppl = "source=EMP | stats avg(SAL) by span(EMPNO, 500) as empno_span, DEPTNO | sort DEPTNO, empno_span";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[ASC])\n"
+ " LogicalAggregate(group=[{1, 2}], avg(SAL)=[AVG($0)])\n"
+ " LogicalProject(SAL=[$5], DEPTNO=[$7], empno_span=[*(FLOOR(/($0, 500)), 500)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "DEPTNO=10; empno_span=7500.0; avg(SAL)=2916.66\n"
+ "DEPTNO=20; empno_span=7000.0; avg(SAL)=800.00\n"
+ "DEPTNO=20; empno_span=7500.0; avg(SAL)=2518.75\n"
+ "DEPTNO=30; empno_span=7000.0; avg(SAL)=1600.00\n"
+ "DEPTNO=30; empno_span=7500.0; avg(SAL)=1560.00\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `DEPTNO`, FLOOR(`EMPNO` / 500) * 500 `empno_span`, AVG(`SAL`) `avg(SAL)`\n"
+ "FROM `scott`.`EMP`\n"
+ "GROUP BY `DEPTNO`, FLOOR(`EMPNO` / 500) * 500\n"
+ "ORDER BY `DEPTNO` NULLS LAST, 2 NULLS LAST";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Ignore
public void testAvgByTimeSpanAndFields() {
String ppl = "source=EMP | stats avg(SAL) by span(HIREDATE, 1y) as hiredate_span, DEPTNO | sort DEPTNO, hiredate_span";
RelNode root = getRelNode(ppl);
String expectedLogical = "";
verifyLogical(root, expectedLogical);
String expectedResult = "";
verifyResult(root, expectedResult);
}

@Test
public void testCountDistinct() {
String ppl = "source=EMP | stats distinct_count(JOB) by DEPTNO";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalAggregate(group=[{7}], distinct_count(JOB)=[COUNT(DISTINCT $2)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "DEPTNO=20; distinct_count(JOB)=3\n"
+ "DEPTNO=10; distinct_count(JOB)=3\n"
+ "DEPTNO=30; distinct_count(JOB)=3\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `DEPTNO`, COUNT(DISTINCT `JOB`) `distinct_count(JOB)`\n"
+ "FROM `scott`.`EMP`\n"
+ "GROUP BY `DEPTNO`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Ignore
public void testMultipleLevelStats() {
// TODO unsupported
String ppl = "source=EMP | stats avg(SAL) as avg_sal | stats avg(COMM) as avg_comm";
RelNode root = getRelNode(ppl);
String expectedLogical = "";
verifyLogical(root, expectedLogical);
String expectedResult = "";
verifyResult(root, expectedResult);
}


}
316 changes: 316 additions & 0 deletions CalcitePPLBasicTest.java
@@ -0,0 +1,316 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl;

import org.apache.calcite.rel.RelNode;

import org.apache.calcite.test.CalciteAssert;
import org.junit.Test;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assertions.assertThrows;

public class CalcitePPLBasicTest extends CalcitePPLAbstractTest {

public CalcitePPLBasicTest() {
super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL);
}

@Test
public void testInvalidTable() {
String ppl = "source=unknown";
try {
RelNode root = getRelNode(ppl);
fail("expected error, got " + root);
} catch (Exception e) {
assertThat(e.getMessage(), is("Table 'unknown' not found"));
}
}

@Test
public void testScanTable() {
String ppl = "source=products_temporal";
RelNode root = getRelNode(ppl);
verifyLogical(root, "LogicalTableScan(table=[[scott, products_temporal]])\n");
}

@Test
public void testScanTableTwoParts() {
String ppl = "source=`scott`.`products_temporal`";
RelNode root = getRelNode(ppl);
verifyLogical(root, "LogicalTableScan(table=[[scott, products_temporal]])\n");
}

@Test
public void testFilterQuery() {
String ppl = "source=scott.products_temporal | where SUPPLIER > 0 AND ID = '1000'";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalFilter(condition=[AND(>($1, 0), =($0, '1000'))])\n"
+ " LogicalTableScan(table=[[scott, products_temporal]])\n";
verifyLogical(root, expectedLogical);

String expectedSparkSql = ""
+ "SELECT *\n"
+ "FROM `scott`.`products_temporal`\n"
+ "WHERE `SUPPLIER` > 0 AND `ID` = '1000'";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testFilterQueryWithOr() {
String ppl = "source=EMP | where (DEPTNO = 20 or MGR = 30) and SAL > 1000 | fields EMPNO, ENAME";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], ENAME=[$1])\n"
+ " LogicalFilter(condition=[AND(OR(=($7, 20), =($3, 30)), >($5, 1000))])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);

String expectedSparkSql = ""
+ "SELECT `EMPNO`, `ENAME`\n"
+ "FROM `scott`.`EMP`\n"
+ "WHERE (`DEPTNO` = 20 OR `MGR` = 30) AND `SAL` > 1000";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testFilterQueryWithOr2() {
String ppl = "source=EMP (DEPTNO = 20 or MGR = 30) and SAL > 1000 | fields EMPNO, ENAME";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], ENAME=[$1])\n"
+ " LogicalFilter(condition=[AND(OR(=($7, 20), =($3, 30)), >($5, 1000))])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);

String expectedSparkSql = ""
+ "SELECT `EMPNO`, `ENAME`\n"
+ "FROM `scott`.`EMP`\n"
+ "WHERE (`DEPTNO` = 20 OR `MGR` = 30) AND `SAL` > 1000";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testQueryWithFields() {
String ppl = "source=products_temporal | fields SUPPLIER, ID";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(SUPPLIER=[$1], ID=[$0])\n"
+ " LogicalTableScan(table=[[scott, products_temporal]])\n";
verifyLogical(root, expectedLogical);

String expectedSparkSql = ""
+ "SELECT `SUPPLIER`, `ID`\n"
+ "FROM `scott`.`products_temporal`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testQueryMinusFields() {
String ppl = "source=products_temporal | fields - SUPPLIER, ID";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(SYS_START=[$2], SYS_END=[$3])\n"
+ " LogicalTableScan(table=[[scott, products_temporal]])\n";
verifyLogical(root, expectedLogical);

String expectedSparkSql = ""
+ "SELECT `SYS_START`, `SYS_END`\n"
+ "FROM `scott`.`products_temporal`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testFieldsPlusThenMinus() {
String ppl = "source=EMP | fields + EMPNO, DEPTNO, SAL | fields - DEPTNO, SAL";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
}

@Test
public void testFieldsMinusThenPlusShouldThrowException() {
String ppl = "source=EMP | fields - DEPTNO, SAL | fields + EMPNO, DEPTNO, SAL";
IllegalArgumentException e =
assertThrows(IllegalArgumentException.class, () -> getRelNode(ppl));
assertThat(e.getMessage(),
is("field [DEPTNO] not found; input fields are: [EMPNO, ENAME, JOB, MGR, HIREDATE, COMM]"));
}

@Test
public void testScanTableAndCheckResults() {
String ppl = "source=EMP | where DEPTNO = 20";
RelNode root = getRelNode(ppl);
String expectedResult = ""
+ "EMPNO=7369; ENAME=SMITH; JOB=CLERK; MGR=7902; HIREDATE=1980-12-17; SAL=800.00; COMM=null; DEPTNO=20\n"
+ "EMPNO=7566; ENAME=JONES; JOB=MANAGER; MGR=7839; HIREDATE=1981-02-04; SAL=2975.00; COMM=null; DEPTNO=20\n"
+ "EMPNO=7788; ENAME=SCOTT; JOB=ANALYST; MGR=7566; HIREDATE=1987-04-19; SAL=3000.00; COMM=null; DEPTNO=20\n"
+ "EMPNO=7876; ENAME=ADAMS; JOB=CLERK; MGR=7788; HIREDATE=1987-05-23; SAL=1100.00; COMM=null; DEPTNO=20\n"
+ "EMPNO=7902; ENAME=FORD; JOB=ANALYST; MGR=7566; HIREDATE=1981-12-03; SAL=3000.00; COMM=null; DEPTNO=20\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT *\n"
+ "FROM `scott`.`EMP`\n"
+ "WHERE `DEPTNO` = 20";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testSort() {
String ppl = "source=EMP | sort DEPTNO";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalSort(sort0=[$7], dir0=[ASC])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
}

@Test
public void testSortTwoFields() {
String ppl = "source=EMP | sort DEPTNO, SAL";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalSort(sort0=[$7], sort1=[$5], dir0=[ASC], dir1=[ASC])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
}

@Test
public void testSortWithDesc() {
String ppl = "source=EMP | sort + DEPTNO, - SAL | fields EMPNO, DEPTNO, SAL";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalSort(sort0=[$7], sort1=[$5], dir0=[ASC], dir1=[DESC])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "EMPNO=7839; DEPTNO=10; SAL=5000.00\n"
+ "EMPNO=7782; DEPTNO=10; SAL=2450.00\n"
+ "EMPNO=7934; DEPTNO=10; SAL=1300.00\n"
+ "EMPNO=7788; DEPTNO=20; SAL=3000.00\n"
+ "EMPNO=7902; DEPTNO=20; SAL=3000.00\n"
+ "EMPNO=7566; DEPTNO=20; SAL=2975.00\n"
+ "EMPNO=7876; DEPTNO=20; SAL=1100.00\n"
+ "EMPNO=7369; DEPTNO=20; SAL=800.00\n"
+ "EMPNO=7698; DEPTNO=30; SAL=2850.00\n"
+ "EMPNO=7499; DEPTNO=30; SAL=1600.00\n"
+ "EMPNO=7844; DEPTNO=30; SAL=1500.00\n"
+ "EMPNO=7521; DEPTNO=30; SAL=1250.00\n"
+ "EMPNO=7654; DEPTNO=30; SAL=1250.00\n"
+ "EMPNO=7900; DEPTNO=30; SAL=950.00\n";
verifyResult(root, expectedResult);
}

@Test
public void testSortWithDescAndLimit() {
String ppl = "source=EMP | sort - SAL | fields EMPNO, DEPTNO, SAL | head 3";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalSort(sort0=[$5], dir0=[DESC], fetch=[3])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "EMPNO=7839; DEPTNO=10; SAL=5000.00\n"
+ "EMPNO=7788; DEPTNO=20; SAL=3000.00\n"
+ "EMPNO=7902; DEPTNO=20; SAL=3000.00\n";
verifyResult(root, expectedResult);
}

@Test
public void testMultipleTables() {
String ppl = "source=EMP, EMP";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalUnion(all=[true])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
}

@Test
public void testMultipleTablesAndFilters() {
String ppl = "source=EMP, EMP DEPTNO = 20 | fields EMPNO, DEPTNO, SAL";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalFilter(condition=[=($7, 20)])\n"
+ " LogicalUnion(all=[true])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "EMPNO=7369; DEPTNO=20; SAL=800.00\n"
+ "EMPNO=7566; DEPTNO=20; SAL=2975.00\n"
+ "EMPNO=7788; DEPTNO=20; SAL=3000.00\n"
+ "EMPNO=7876; DEPTNO=20; SAL=1100.00\n"
+ "EMPNO=7902; DEPTNO=20; SAL=3000.00\n"
+ "EMPNO=7369; DEPTNO=20; SAL=800.00\n"
+ "EMPNO=7566; DEPTNO=20; SAL=2975.00\n"
+ "EMPNO=7788; DEPTNO=20; SAL=3000.00\n"
+ "EMPNO=7876; DEPTNO=20; SAL=1100.00\n"
+ "EMPNO=7902; DEPTNO=20; SAL=3000.00\n";

verifyResult(root, expectedResult);
}

@Test
public void testLineComments() {
String ppl1 = "source=products_temporal //this is a comment";
verifyLogical(getRelNode(ppl1), "LogicalTableScan(table=[[scott, products_temporal]])\n");
String ppl2 = "source=products_temporal // this is a comment";
verifyLogical(getRelNode(ppl2), "LogicalTableScan(table=[[scott, products_temporal]])\n");
String ppl3 = ""
+ "// test is a new line comment\n"
+ "source=products_temporal // this is a comment\n"
+ "| fields SUPPLIER, ID // this is line comment inner ppl command\n"
+ "////this is a new line comment";
String expectedLogical = ""
+ "LogicalProject(SUPPLIER=[$1], ID=[$0])\n"
+ " LogicalTableScan(table=[[scott, products_temporal]])\n";
verifyLogical(getRelNode(ppl3), expectedLogical);
}

@Test
public void testBlockComments() {
String ppl1 = "/* this is a block comment */ source=products_temporal";
verifyLogical(getRelNode(ppl1), "LogicalTableScan(table=[[scott, products_temporal]])\n");
String ppl2 = "source=products_temporal | /*this is a block comment*/ fields SUPPLIER, ID";
String expectedLogical2 = ""
+ "LogicalProject(SUPPLIER=[$1], ID=[$0])\n"
+ " LogicalTableScan(table=[[scott, products_temporal]])\n";
verifyLogical(getRelNode(ppl2), expectedLogical2);
String ppl3 = ""
+ "/*\n"
+ " * This is a\n"
+ " * multiple\n"
+ " * line\n"
+ " * block\n"
+ " * comment\n"
+ " */\n"
+ "search /* block comment */ source=products_temporal /* block comment */ ID = 0\n"
+ "| /*\n"
+ " This is a\n"
+ " multiple\n"
+ " line\n"
+ " block\n"
+ " comment */ fields SUPPLIER, ID /* block comment */\n"
+ "/* block comment */";
String expectedLogical3 = ""
+ "LogicalProject(SUPPLIER=[$1], ID=[$0])\n"
+ " LogicalFilter(condition=[=($0, 0)])\n"
+ " LogicalTableScan(table=[[scott, products_temporal]])\n";
verifyLogical(getRelNode(ppl3), expectedLogical3);
}

}
57 changes: 57 additions & 0 deletions CalcitePPLDateTimeFunctionTest.java
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.test.CalciteAssert;
import org.junit.Test;

import java.time.LocalDate;

public class CalcitePPLDateTimeFunctionTest extends CalcitePPLAbstractTest {

public CalcitePPLDateTimeFunctionTest() {
super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL);
}

@Test
public void testDateAndCurrentTimestamp() {
String ppl = "source=EMP | eval added = DATE(CURRENT_TIMESTAMP()) | fields added | head 1";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalSort(fetch=[1])\n"
+ " LogicalProject(added=[DATE(CURRENT_TIMESTAMP)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = "added=" + LocalDate.now() + "\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT DATE(CURRENT_TIMESTAMP) `added`\n"
+ "FROM `scott`.`EMP`\n"
+ "LIMIT 1";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testCurrentDate() {
String ppl = "source=EMP | eval added = CURRENT_DATE() | fields added | head 1";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalSort(fetch=[1])\n"
+ " LogicalProject(added=[CURRENT_DATE])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = "added=" + LocalDate.now() + "\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT CURRENT_DATE `added`\n"
+ "FROM `scott`.`EMP`\n"
+ "LIMIT 1";
verifyPPLToSparkSQL(root, expectedSparkSql);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.test.CalciteAssert;
import org.junit.Test;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

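/**
 * Tests the PPL eval command: adding, chaining, and overriding computed fields,
 * and their interaction with sort, fields, head, and stats.
 */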
public class CalcitePPLEvalTest extends CalcitePPLAbstractTest {

public CalcitePPLEvalTest() {
super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL);
}

@Test
public void testEval1() {
String ppl = "source=EMP | eval a = 1";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], a=[1])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);

String expectedSparkSql = ""
+ "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, 1 `a`\n"
+ "FROM `scott`.`EMP`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testEvalAndFields() {
String ppl = "source=EMP | eval a = 1 | fields EMPNO, a";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], a=[1])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "EMPNO=7369; a=1\n"
+ "EMPNO=7499; a=1\n"
+ "EMPNO=7521; a=1\n"
+ "EMPNO=7566; a=1\n"
+ "EMPNO=7654; a=1\n"
+ "EMPNO=7698; a=1\n"
+ "EMPNO=7782; a=1\n"
+ "EMPNO=7788; a=1\n"
+ "EMPNO=7839; a=1\n"
+ "EMPNO=7844; a=1\n"
+ "EMPNO=7876; a=1\n"
+ "EMPNO=7900; a=1\n"
+ "EMPNO=7902; a=1\n"
+ "EMPNO=7934; a=1\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `EMPNO`, 1 `a`\n"
+ "FROM `scott`.`EMP`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testEval2() {
String ppl = "source=EMP | eval a = 1, b = 2";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], a=[1], b=[2])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);

String expectedSparkSql = ""
+ "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, 1 `a`, 2 `b`\n"
+ "FROM `scott`.`EMP`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testEval3() {
String ppl = "source=EMP | eval a = 1 | eval b = 2 | eval c = 3";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], a=[1], b=[2], c=[3])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);

String expectedSparkSql = ""
+ "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, 1 `a`, 2 `b`, 3 `c`\n"
+ "FROM `scott`.`EMP`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testEvalWithSort() {
String ppl = "source=EMP | eval a = EMPNO | sort - a | fields a";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(a=[$8])\n"
+ " LogicalSort(sort0=[$8], dir0=[DESC])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], a=[$0])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "a=7934\n"
+ "a=7902\n"
+ "a=7900\n"
+ "a=7876\n"
+ "a=7844\n"
+ "a=7839\n"
+ "a=7788\n"
+ "a=7782\n"
+ "a=7698\n"
+ "a=7654\n"
+ "a=7566\n"
+ "a=7521\n"
+ "a=7499\n"
+ "a=7369\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `EMPNO` `a`\n"
+ "FROM `scott`.`EMP`\n"
+ "ORDER BY `EMPNO` DESC NULLS FIRST";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testEvalUsingExistingFields() {
String ppl = "source=EMP | eval EMPNO_PLUS = EMPNO + 1 | sort - EMPNO_PLUS | fields EMPNO, EMPNO_PLUS | head 3";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], EMPNO_PLUS=[$8])\n"
+ " LogicalSort(sort0=[$8], dir0=[DESC], fetch=[3])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], EMPNO_PLUS=[+($0, 1)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "EMPNO=7934; EMPNO_PLUS=7935\n"
+ "EMPNO=7902; EMPNO_PLUS=7903\n"
+ "EMPNO=7900; EMPNO_PLUS=7901\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `EMPNO`, `EMPNO_PLUS`\n"
+ "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, `EMPNO` + 1 `EMPNO_PLUS`\n"
+ "FROM `scott`.`EMP`\n"
+ "ORDER BY 9 DESC NULLS FIRST\n"
+ "LIMIT 3) `t0`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testEvalOverridingExistingFields() {
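// Overriding an existing field drops the original SAL column; the new expression surfaces as SAL0.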
String ppl = "source=EMP | eval SAL = DEPTNO + 10000 | sort - EMPNO | fields EMPNO, SAL | head 3";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(EMPNO=[$0], SAL0=[$7])\n"
+ " LogicalSort(sort0=[$0], dir0=[DESC], fetch=[3])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], COMM=[$6], DEPTNO=[$7], SAL0=[+($7, 10000)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "EMPNO=7934; SAL0=10010\n"
+ "EMPNO=7902; SAL0=10020\n"
+ "EMPNO=7900; SAL0=10030\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `EMPNO`, `DEPTNO` + 10000 `SAL0`\n"
+ "FROM `scott`.`EMP`\n"
+ "ORDER BY `EMPNO` DESC NULLS FIRST\n"
+ "LIMIT 3";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testComplexEvalCommands1() {
String ppl = "source=EMP | eval col1 = 1 | sort col1 | head 4 | eval col2 = 2 | sort - col2 | sort EMPNO | head 2 | fields EMPNO, ENAME, col2";
RelNode root = getRelNode(ppl);
String expectedResult = ""
+ "EMPNO=7369; ENAME=SMITH; col2=2\n"
+ "EMPNO=7499; ENAME=ALLEN; col2=2\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `EMPNO`, `ENAME`, `col2`\n"
+ "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, `col1`, `col2`\n"
+ "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, 1 `col1`, 2 `col2`\n"
+ "FROM `scott`.`EMP`\n"
+ "ORDER BY '1' NULLS LAST\n"
+ "LIMIT 4) `t1`\n"
+ "ORDER BY `col2` DESC NULLS FIRST) `t2`\n"
+ "ORDER BY `EMPNO` NULLS LAST\n"
+ "LIMIT 2";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testComplexEvalCommands2() {
String ppl = "source=EMP | eval col1 = SAL | sort - col1 | head 3 | eval col2 = SAL | sort + col2 | fields ENAME, SAL | head 2";
RelNode root = getRelNode(ppl);
String expectedResult = ""
+ "ENAME=SCOTT; SAL=3000.00\n"
+ "ENAME=FORD; SAL=3000.00\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `ENAME`, `SAL`\n"
+ "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, `SAL` `col1`, `SAL` `col2`\n"
+ "FROM `scott`.`EMP`\n"
+ "ORDER BY `SAL` DESC NULLS FIRST\n"
+ "LIMIT 3) `t1`\n"
+ "ORDER BY `col2` NULLS LAST\n"
+ "LIMIT 2";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testComplexEvalCommands3() {
String ppl = "source=EMP | eval col1 = SAL | sort - col1 | head 3 | fields ENAME, col1 | eval col2 = col1 | sort + col2 | fields ENAME, col2 | eval col3 = col2 | head 2 | fields ENAME, col3";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(ENAME=[$0], col3=[$2])\n"
+ " LogicalSort(sort0=[$2], dir0=[ASC], fetch=[2])\n"
+ " LogicalProject(ENAME=[$1], col1=[$8], col2=[$8])\n"
+ " LogicalSort(sort0=[$8], dir0=[DESC], fetch=[3])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], col1=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "ENAME=SCOTT; col3=3000.00\n"
+ "ENAME=FORD; col3=3000.00\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `ENAME`, `col2` `col3`\n"
+ "FROM (SELECT `ENAME`, `SAL` `col1`, `SAL` `col2`\n"
+ "FROM `scott`.`EMP`\n"
+ "ORDER BY `SAL` DESC NULLS FIRST\n"
+ "LIMIT 3) `t1`\n"
+ "ORDER BY `col2` NULLS LAST\n"
+ "LIMIT 2";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testComplexEvalCommands4() {
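// HIREDATE was projected away by the earlier fields command, so resolving it in the final fields fails.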
String ppl = "source=EMP | eval col1 = SAL | sort - col1 | head 3 | fields ENAME, col1 | eval col2 = col1 | sort + col2 | fields ENAME, col2 | eval col3 = col2 | head 2 | fields HIREDATE, col3";
IllegalArgumentException e =
assertThrows(IllegalArgumentException.class, () -> getRelNode(ppl));
assertThat(e.getMessage(),
is("field [HIREDATE] not found; input fields are: [ENAME, col2, col3]"));
}

@Test
public void testEvalWithAggregation() {
String ppl = "source=EMP | eval a = SAL, b = DEPTNO | stats avg(a) by b";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalAggregate(group=[{1}], avg(a)=[AVG($0)])\n"
+ " LogicalProject(a=[$5], b=[$7])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "b=20; avg(a)=2175.00\n"
+ "b=10; avg(a)=2916.66\n"
+ "b=30; avg(a)=1566.66\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `DEPTNO` `b`, AVG(`SAL`) `avg(a)`\n"
+ "FROM `scott`.`EMP`\n"
+ "GROUP BY `DEPTNO`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testDependedEval() {
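// The second eval references the first, so both collapse into a single project computing SAL + 10000.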
String ppl = "source=EMP | eval a = SAL | eval b = a + 10000 | stats avg(b) by DEPTNO";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalAggregate(group=[{0}], avg(b)=[AVG($1)])\n"
+ " LogicalProject(DEPTNO=[$7], b=[+($5, 10000)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "DEPTNO=20; avg(b)=12175.00\n"
+ "DEPTNO=10; avg(b)=12916.66\n"
+ "DEPTNO=30; avg(b)=11566.66\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `DEPTNO`, AVG(`SAL` + 10000) `avg(b)`\n"
+ "FROM `scott`.`EMP`\n"
+ "GROUP BY `DEPTNO`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testDependedLateralEval() {
String ppl = "source=EMP | eval a = SAL, b = a + 10000 | stats avg(b) by DEPTNO";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalAggregate(group=[{0}], avg(b)=[AVG($1)])\n"
+ " LogicalProject(DEPTNO=[$7], b=[+($5, 10000)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "DEPTNO=20; avg(b)=12175.00\n"
+ "DEPTNO=10; avg(b)=12916.66\n"
+ "DEPTNO=30; avg(b)=11566.66\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT `DEPTNO`, AVG(`SAL` + 10000) `avg(b)`\n"
+ "FROM `scott`.`EMP`\n"
+ "GROUP BY `DEPTNO`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.test.CalciteAssert;
import org.junit.Test;

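/**
 * Tests PPL join commands (inner, left, cross, and non-equi joins) with conditions
 * that reference table names, subsearch aliases, or bare field names.
 */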
public class CalcitePPLJoinTest extends CalcitePPLAbstractTest {

public CalcitePPLJoinTest() {
super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL);
}

@Test
public void testJoinConditionWithTableNames() {
String ppl = "source=EMP | join on EMP.DEPTNO = DEPT.DEPTNO DEPT";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
verifyLogical(root, expectedLogical);
verifyResultCount(root, 14);

String expectedSparkSql = ""
+ "SELECT *\n"
+ "FROM `scott`.`EMP`\n"
+ "INNER JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` = `DEPT`.`DEPTNO`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testJoinConditionWithAlias() {
String ppl = "source=EMP as e | join on e.DEPTNO = d.DEPTNO DEPT as d";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
verifyLogical(root, expectedLogical);
verifyResultCount(root, 14);

String expectedSparkSql = ""
+ "SELECT *\n"
+ "FROM `scott`.`EMP`\n"
+ "INNER JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` = `DEPT`.`DEPTNO`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testJoinConditionWithoutTableName() {
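// ENAME resolves to EMP and DNAME to DEPT; no employee name equals a department name, hence zero rows.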
String ppl = "source=EMP | join on ENAME = DNAME DEPT";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalJoin(condition=[=($1, $9)], joinType=[inner])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
verifyLogical(root, expectedLogical);
verifyResultCount(root, 0);

String expectedSparkSql = ""
+ "SELECT *\n"
+ "FROM `scott`.`EMP`\n"
+ "INNER JOIN `scott`.`DEPT` ON `EMP`.`ENAME` = `DEPT`.`DNAME`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testJoinWithSpecificAliases() {
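// left = l and right = r declare aliases for the two join sides, which the condition then references.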
String ppl = "source=EMP | join left = l right = r on l.DEPTNO = r.DEPTNO DEPT";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
verifyLogical(root, expectedLogical);
verifyResultCount(root, 14);

String expectedSparkSql = ""
+ "SELECT *\n"
+ "FROM `scott`.`EMP`\n"
+ "INNER JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` = `DEPT`.`DEPTNO`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testLeftJoin() {
String ppl = "source=EMP as e | left join on e.DEPTNO = d.DEPTNO DEPT as d";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalJoin(condition=[=($7, $8)], joinType=[left])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
verifyLogical(root, expectedLogical);
verifyResultCount(root, 14);

String expectedSparkSql = ""
+ "SELECT *\n"
+ "FROM `scott`.`EMP`\n"
+ "LEFT JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` = `DEPT`.`DEPTNO`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testCrossJoin() {
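// Cartesian product: 14 EMP rows x 4 DEPT rows = 56 rows.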
String ppl = "source=EMP as e | cross join DEPT as d";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalJoin(condition=[true], joinType=[inner])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
verifyLogical(root, expectedLogical);
verifyResultCount(root, 56);

String expectedSparkSql = ""
+ "SELECT *\n"
+ "FROM `scott`.`EMP`\n"
+ "CROSS JOIN `scott`.`DEPT`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testNonEquiJoin() {
String ppl = "source=EMP as e | join on e.DEPTNO > d.DEPTNO DEPT as d";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalJoin(condition=[>($7, $8)], joinType=[inner])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
verifyLogical(root, expectedLogical);
verifyResultCount(root, 17);

String expectedSparkSql = ""
+ "SELECT *\n"
+ "FROM `scott`.`EMP`\n"
+ "INNER JOIN `scott`.`DEPT` ON `EMP`.`DEPTNO` > `DEPT`.`DEPTNO`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.test.CalciteAssert;
import org.junit.Test;

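/**
 * Tests translation of PPL math functions (currently abs) into Calcite logical plans,
 * execution results, and Spark SQL.
 */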
public class CalcitePPLMathFunctionTest extends CalcitePPLAbstractTest {

public CalcitePPLMathFunctionTest() {
super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL);
}

@Test
public void testAbsWithOverriding() {
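// Overriding SAL with a constant expression: the original column is dropped and ABS(-30) appears as SAL0.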
String ppl = "source=EMP | eval SAL = abs(-30) | head 10 | fields SAL";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(SAL0=[$7])\n"
+ " LogicalSort(fetch=[10])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], COMM=[$6], DEPTNO=[$7], SAL0=[ABS(-30)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "SAL0=30\n"
+ "SAL0=30\n"
+ "SAL0=30\n"
+ "SAL0=30\n"
+ "SAL0=30\n"
+ "SAL0=30\n"
+ "SAL0=30\n"
+ "SAL0=30\n"
+ "SAL0=30\n"
+ "SAL0=30\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT ABS(-30) `SAL0`\n"
+ "FROM `scott`.`EMP`\n"
+ "LIMIT 10";
verifyPPLToSparkSQL(root, expectedSparkSql);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.test.CalciteAssert;
import org.junit.Test;

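/**
 * Tests translation of PPL string functions (lower, like) into Calcite logical plans,
 * execution results, and Spark SQL.
 */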
public class CalcitePPLStringFunctionTest extends CalcitePPLAbstractTest {

public CalcitePPLStringFunctionTest() {
super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL);
}

@Test
public void testLower() {
String ppl = "source=EMP | eval lower_name = lower(ENAME) | fields lower_name";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalProject(lower_name=[LOWER($1)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = ""
+ "lower_name=smith\n"
+ "lower_name=allen\n"
+ "lower_name=ward\n"
+ "lower_name=jones\n"
+ "lower_name=martin\n"
+ "lower_name=blake\n"
+ "lower_name=clark\n"
+ "lower_name=scott\n"
+ "lower_name=king\n"
+ "lower_name=turner\n"
+ "lower_name=adams\n"
+ "lower_name=james\n"
+ "lower_name=ford\n"
+ "lower_name=miller\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT LOWER(`ENAME`) `lower_name`\n"
+ "FROM `scott`.`EMP`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testLike() {
String ppl = "source=EMP | where like(JOB, 'SALE%') | stats count() as cnt";
RelNode root = getRelNode(ppl);
String expectedLogical = ""
+ "LogicalAggregate(group=[{}], cnt=[COUNT()])\n"
+ " LogicalFilter(condition=[LIKE($2, 'SALE%')])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult = "cnt=4\n";
verifyResult(root, expectedResult);

String expectedSparkSql = ""
+ "SELECT COUNT(*) `cnt`\n"
+ "FROM `scott`.`EMP`\n"
+ "WHERE `JOB` LIKE 'SALE%'";
verifyPPLToSparkSQL(root, expectedSparkSql);
}
}