Skip to content

Commit ed0ca8d

Browse files
jduo and acarbonetto authored
Add trendline PPL command (#3071)
* Add trendline (With SWA) PPL command --------- Signed-off-by: James Duong <[email protected]> Signed-off-by: Andrew Carbonetto <[email protected]> Co-authored-by: Andrew Carbonetto <[email protected]>
1 parent 3e2cb1d commit ed0ca8d

File tree

33 files changed

+1601
-23
lines changed

33 files changed

+1601
-23
lines changed

core/src/main/java/org/opensearch/sql/analysis/Analyzer.java

+77-17
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_LAST;
1111
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
1212
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
13+
import static org.opensearch.sql.data.type.ExprCoreType.DATE;
1314
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
15+
import static org.opensearch.sql.data.type.ExprCoreType.TIME;
16+
import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP;
1417
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
1518
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
1619
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
@@ -22,6 +25,7 @@
2225
import com.google.common.collect.ImmutableMap;
2326
import com.google.common.collect.ImmutableSet;
2427
import java.util.ArrayList;
28+
import java.util.Collections;
2529
import java.util.List;
2630
import java.util.Objects;
2731
import java.util.Optional;
@@ -62,6 +66,7 @@
6266
import org.opensearch.sql.ast.tree.Sort;
6367
import org.opensearch.sql.ast.tree.Sort.SortOption;
6468
import org.opensearch.sql.ast.tree.TableFunction;
69+
import org.opensearch.sql.ast.tree.Trendline;
6570
import org.opensearch.sql.ast.tree.UnresolvedPlan;
6671
import org.opensearch.sql.ast.tree.Values;
6772
import org.opensearch.sql.common.antlr.SyntaxCheckException;
@@ -100,6 +105,7 @@
100105
import org.opensearch.sql.planner.logical.LogicalRemove;
101106
import org.opensearch.sql.planner.logical.LogicalRename;
102107
import org.opensearch.sql.planner.logical.LogicalSort;
108+
import org.opensearch.sql.planner.logical.LogicalTrendline;
103109
import org.opensearch.sql.planner.logical.LogicalValues;
104110
import org.opensearch.sql.planner.physical.datasource.DataSourceTable;
105111
import org.opensearch.sql.storage.Table;
@@ -469,23 +475,7 @@ public LogicalPlan visitParse(Parse node, AnalysisContext context) {
469475
@Override
470476
public LogicalPlan visitSort(Sort node, AnalysisContext context) {
471477
LogicalPlan child = node.getChild().get(0).accept(this, context);
472-
ExpressionReferenceOptimizer optimizer =
473-
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);
474-
475-
List<Pair<SortOption, Expression>> sortList =
476-
node.getSortList().stream()
477-
.map(
478-
sortField -> {
479-
var analyzed = expressionAnalyzer.analyze(sortField.getField(), context);
480-
if (analyzed == null) {
481-
throw new UnsupportedOperationException(
482-
String.format("Invalid use of expression %s", sortField.getField()));
483-
}
484-
Expression expression = optimizer.optimize(analyzed, context);
485-
return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression);
486-
})
487-
.collect(Collectors.toList());
488-
return new LogicalSort(child, sortList);
478+
return buildSort(child, context, node.getSortList());
489479
}
490480

491481
/** Build {@link LogicalDedupe}. */
@@ -594,6 +584,55 @@ public LogicalPlan visitML(ML node, AnalysisContext context) {
594584
return new LogicalML(child, node.getArguments());
595585
}
596586

587+
/** Build {@link LogicalTrendline} for Trendline command. */
588+
@Override
589+
public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) {
590+
final LogicalPlan child = node.getChild().get(0).accept(this, context);
591+
592+
final TypeEnvironment currEnv = context.peek();
593+
final List<Trendline.TrendlineComputation> computations = node.getComputations();
594+
final ImmutableList.Builder<Pair<Trendline.TrendlineComputation, ExprCoreType>>
595+
computationsAndTypes = ImmutableList.builder();
596+
computations.forEach(
597+
computation -> {
598+
final Expression resolvedField =
599+
expressionAnalyzer.analyze(computation.getDataField(), context);
600+
final ExprCoreType averageType;
601+
// Duplicate the semantics of AvgAggregator#create():
602+
// - All numerical types have the DOUBLE type for the moving average.
603+
// - All datetime types have the same datetime type for the moving average.
604+
if (ExprCoreType.numberTypes().contains(resolvedField.type())) {
605+
averageType = ExprCoreType.DOUBLE;
606+
} else {
607+
switch (resolvedField.type()) {
608+
case DATE:
609+
case TIME:
610+
case TIMESTAMP:
611+
averageType = (ExprCoreType) resolvedField.type();
612+
break;
613+
default:
614+
throw new SemanticCheckException(
615+
String.format(
616+
"Invalid field used for trendline computation %s. Source field %s had type"
617+
+ " %s but must be a numerical or datetime field.",
618+
computation.getAlias(),
619+
computation.getDataField().getChild().get(0),
620+
resolvedField.type().typeName()));
621+
}
622+
}
623+
currEnv.define(new Symbol(Namespace.FIELD_NAME, computation.getAlias()), averageType);
624+
computationsAndTypes.add(Pair.of(computation, averageType));
625+
});
626+
627+
if (node.getSortByField().isEmpty()) {
628+
return new LogicalTrendline(child, computationsAndTypes.build());
629+
}
630+
631+
return new LogicalTrendline(
632+
buildSort(child, context, Collections.singletonList(node.getSortByField().get())),
633+
computationsAndTypes.build());
634+
}
635+
597636
@Override
598637
public LogicalPlan visitPaginate(Paginate paginate, AnalysisContext context) {
599638
LogicalPlan child = paginate.getChild().get(0).accept(this, context);
@@ -612,6 +651,27 @@ public LogicalPlan visitCloseCursor(CloseCursor closeCursor, AnalysisContext con
612651
return new LogicalCloseCursor(closeCursor.getChild().get(0).accept(this, context));
613652
}
614653

654+
private LogicalSort buildSort(
655+
LogicalPlan child, AnalysisContext context, List<Field> sortFields) {
656+
ExpressionReferenceOptimizer optimizer =
657+
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);
658+
659+
List<Pair<SortOption, Expression>> sortList =
660+
sortFields.stream()
661+
.map(
662+
sortField -> {
663+
var analyzed = expressionAnalyzer.analyze(sortField.getField(), context);
664+
if (analyzed == null) {
665+
throw new UnsupportedOperationException(
666+
String.format("Invalid use of expression %s", sortField.getField()));
667+
}
668+
Expression expression = optimizer.optimize(analyzed, context);
669+
return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression);
670+
})
671+
.collect(Collectors.toList());
672+
return new LogicalSort(child, sortList);
673+
}
674+
615675
/**
616676
* The first argument is always "asc", others are optional. Given nullFirst argument, use its
617677
* value. Otherwise just use DEFAULT_ASC/DESC.

core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java

+9
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import org.opensearch.sql.ast.tree.Rename;
6161
import org.opensearch.sql.ast.tree.Sort;
6262
import org.opensearch.sql.ast.tree.TableFunction;
63+
import org.opensearch.sql.ast.tree.Trendline;
6364
import org.opensearch.sql.ast.tree.Values;
6465

6566
/** AST nodes visitor Defines the traverse path. */
@@ -110,6 +111,14 @@ public T visitFilter(Filter node, C context) {
110111
return visitChildren(node, context);
111112
}
112113

114+
public T visitTrendline(Trendline node, C context) {
115+
return visitChildren(node, context);
116+
}
117+
118+
public T visitTrendlineComputation(Trendline.TrendlineComputation node, C context) {
119+
return visitChildren(node, context);
120+
}
121+
113122
public T visitProject(Project node, C context) {
114123
return visitChildren(node, context);
115124
}

core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java

+14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import com.google.common.collect.ImmutableList;
99
import java.util.Arrays;
1010
import java.util.List;
11+
import java.util.Optional;
1112
import java.util.stream.Collectors;
1213
import lombok.experimental.UtilityClass;
1314
import org.apache.commons.lang3.tuple.ImmutablePair;
@@ -62,6 +63,7 @@
6263
import org.opensearch.sql.ast.tree.Sort;
6364
import org.opensearch.sql.ast.tree.Sort.SortOption;
6465
import org.opensearch.sql.ast.tree.TableFunction;
66+
import org.opensearch.sql.ast.tree.Trendline;
6567
import org.opensearch.sql.ast.tree.UnresolvedPlan;
6668
import org.opensearch.sql.ast.tree.Values;
6769

@@ -466,6 +468,18 @@ public static Limit limit(UnresolvedPlan input, Integer limit, Integer offset) {
466468
return new Limit(limit, offset).attach(input);
467469
}
468470

471+
public static Trendline trendline(
472+
UnresolvedPlan input,
473+
Optional<Field> sortField,
474+
Trendline.TrendlineComputation... computations) {
475+
return new Trendline(sortField, Arrays.asList(computations)).attach(input);
476+
}
477+
478+
public static Trendline.TrendlineComputation computation(
479+
Integer numDataPoints, Field dataField, String alias, Trendline.TrendlineType type) {
480+
return new Trendline.TrendlineComputation(numDataPoints, dataField, alias, type);
481+
}
482+
469483
public static Parse parse(
470484
UnresolvedPlan input,
471485
ParseMethod parseMethod,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.ast.tree;
7+
8+
import com.google.common.collect.ImmutableList;
9+
import java.util.List;
10+
import java.util.Optional;
11+
import lombok.EqualsAndHashCode;
12+
import lombok.Getter;
13+
import lombok.RequiredArgsConstructor;
14+
import lombok.ToString;
15+
import org.opensearch.sql.ast.AbstractNodeVisitor;
16+
import org.opensearch.sql.ast.Node;
17+
import org.opensearch.sql.ast.expression.Field;
18+
import org.opensearch.sql.ast.expression.UnresolvedExpression;
19+
20+
@ToString
21+
@Getter
22+
@RequiredArgsConstructor
23+
@EqualsAndHashCode(callSuper = false)
24+
public class Trendline extends UnresolvedPlan {
25+
26+
private UnresolvedPlan child;
27+
private final Optional<Field> sortByField;
28+
private final List<TrendlineComputation> computations;
29+
30+
@Override
31+
public Trendline attach(UnresolvedPlan child) {
32+
this.child = child;
33+
return this;
34+
}
35+
36+
@Override
37+
public List<? extends Node> getChild() {
38+
return ImmutableList.of(child);
39+
}
40+
41+
@Override
42+
public <T, C> T accept(AbstractNodeVisitor<T, C> visitor, C context) {
43+
return visitor.visitTrendline(this, context);
44+
}
45+
46+
@Getter
47+
public static class TrendlineComputation extends UnresolvedExpression {
48+
49+
private final Integer numberOfDataPoints;
50+
private final Field dataField;
51+
private final String alias;
52+
private final TrendlineType computationType;
53+
54+
public TrendlineComputation(
55+
Integer numberOfDataPoints, Field dataField, String alias, TrendlineType computationType) {
56+
this.numberOfDataPoints = numberOfDataPoints;
57+
this.dataField = dataField;
58+
this.alias = alias;
59+
this.computationType = computationType;
60+
}
61+
62+
@Override
63+
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
64+
return nodeVisitor.visitTrendlineComputation(this, context);
65+
}
66+
}
67+
68+
public enum TrendlineType {
69+
SMA
70+
}
71+
}

core/src/main/java/org/opensearch/sql/executor/Explain.java

+32
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
import com.google.common.collect.ImmutableMap;
99
import java.util.ArrayList;
1010
import java.util.List;
11+
import java.util.Locale;
1112
import java.util.Map;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
1415
import java.util.stream.Collectors;
1516
import org.apache.commons.lang3.tuple.Pair;
1617
import org.opensearch.sql.ast.tree.Sort;
18+
import org.opensearch.sql.ast.tree.Trendline;
1719
import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse;
1820
import org.opensearch.sql.executor.ExecutionEngine.ExplainResponseNode;
1921
import org.opensearch.sql.expression.Expression;
@@ -31,6 +33,7 @@
3133
import org.opensearch.sql.planner.physical.RenameOperator;
3234
import org.opensearch.sql.planner.physical.SortOperator;
3335
import org.opensearch.sql.planner.physical.TakeOrderedOperator;
36+
import org.opensearch.sql.planner.physical.TrendlineOperator;
3437
import org.opensearch.sql.planner.physical.ValuesOperator;
3538
import org.opensearch.sql.planner.physical.WindowOperator;
3639
import org.opensearch.sql.storage.TableScanOperator;
@@ -211,6 +214,21 @@ public ExplainResponseNode visitNested(NestedOperator node, Object context) {
211214
explanNode -> explanNode.setDescription(ImmutableMap.of("nested", node.getFields())));
212215
}
213216

217+
@Override
218+
public ExplainResponseNode visitTrendline(TrendlineOperator node, Object context) {
219+
return explain(
220+
node,
221+
context,
222+
explainNode ->
223+
explainNode.setDescription(
224+
ImmutableMap.of(
225+
"computations",
226+
describeTrendlineComputations(
227+
node.getComputations().stream()
228+
.map(Pair::getKey)
229+
.collect(Collectors.toList())))));
230+
}
231+
214232
protected ExplainResponseNode explain(
215233
PhysicalPlan node, Object context, Consumer<ExplainResponseNode> doExplain) {
216234
ExplainResponseNode explainNode = new ExplainResponseNode(getOperatorName(node));
@@ -245,4 +263,18 @@ private Map<String, Map<String, String>> describeSortList(
245263
"sortOrder", p.getLeft().getSortOrder().toString(),
246264
"nullOrder", p.getLeft().getNullOrder().toString())));
247265
}
266+
267+
private List<Map<String, String>> describeTrendlineComputations(
268+
List<Trendline.TrendlineComputation> computations) {
269+
return computations.stream()
270+
.map(
271+
computation ->
272+
ImmutableMap.of(
273+
"computationType",
274+
computation.getComputationType().name().toLowerCase(Locale.ROOT),
275+
"numberOfDataPoints", computation.getNumberOfDataPoints().toString(),
276+
"dataField", computation.getDataField().getChild().get(0).toString(),
277+
"alias", computation.getAlias()))
278+
.collect(Collectors.toList());
279+
}
248280
}

core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java

+7
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.opensearch.sql.planner.logical.LogicalRemove;
2424
import org.opensearch.sql.planner.logical.LogicalRename;
2525
import org.opensearch.sql.planner.logical.LogicalSort;
26+
import org.opensearch.sql.planner.logical.LogicalTrendline;
2627
import org.opensearch.sql.planner.logical.LogicalValues;
2728
import org.opensearch.sql.planner.logical.LogicalWindow;
2829
import org.opensearch.sql.planner.physical.AggregationOperator;
@@ -39,6 +40,7 @@
3940
import org.opensearch.sql.planner.physical.RenameOperator;
4041
import org.opensearch.sql.planner.physical.SortOperator;
4142
import org.opensearch.sql.planner.physical.TakeOrderedOperator;
43+
import org.opensearch.sql.planner.physical.TrendlineOperator;
4244
import org.opensearch.sql.planner.physical.ValuesOperator;
4345
import org.opensearch.sql.planner.physical.WindowOperator;
4446
import org.opensearch.sql.storage.read.TableScanBuilder;
@@ -166,6 +168,11 @@ public PhysicalPlan visitCloseCursor(LogicalCloseCursor node, C context) {
166168
return new CursorCloseOperator(visitChild(node, context));
167169
}
168170

171+
@Override
172+
public PhysicalPlan visitTrendline(LogicalTrendline plan, C context) {
173+
return new TrendlineOperator(visitChild(plan, context), plan.getComputations());
174+
}
175+
169176
// Called when paging query requested without `FROM` clause only
170177
@Override
171178
public PhysicalPlan visitPaginate(LogicalPaginate plan, C context) {

core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java

+7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import org.opensearch.sql.ast.expression.Literal;
1616
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
1717
import org.opensearch.sql.ast.tree.Sort.SortOption;
18+
import org.opensearch.sql.ast.tree.Trendline;
19+
import org.opensearch.sql.data.type.ExprCoreType;
1820
import org.opensearch.sql.expression.Expression;
1921
import org.opensearch.sql.expression.LiteralExpression;
2022
import org.opensearch.sql.expression.NamedExpression;
@@ -130,6 +132,11 @@ public static LogicalPlan rareTopN(
130132
return new LogicalRareTopN(input, commandType, noOfResults, Arrays.asList(fields), groupByList);
131133
}
132134

135+
public static LogicalTrendline trendline(
136+
LogicalPlan input, Pair<Trendline.TrendlineComputation, ExprCoreType>... computations) {
137+
return new LogicalTrendline(input, Arrays.asList(computations));
138+
}
139+
133140
@SafeVarargs
134141
public LogicalPlan values(List<LiteralExpression>... values) {
135142
return new LogicalValues(Arrays.asList(values));

core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java

+4
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ public R visitAD(LogicalAD plan, C context) {
104104
return visitNode(plan, context);
105105
}
106106

107+
public R visitTrendline(LogicalTrendline plan, C context) {
108+
return visitNode(plan, context);
109+
}
110+
107111
public R visitPaginate(LogicalPaginate plan, C context) {
108112
return visitNode(plan, context);
109113
}

0 commit comments

Comments
 (0)