Skip to content

Commit 60f81a9

Browse files
committed
Add rule for unwrapping TIMESTAMP to DATE cast when comparing with DATE literal
1 parent 5922103 commit 60f81a9

File tree

4 files changed

+370
-0
lines changed

4 files changed

+370
-0
lines changed

core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@
230230
import io.trino.sql.planner.iterative.rule.UnwrapCastInComparison;
231231
import io.trino.sql.planner.iterative.rule.UnwrapRowSubscript;
232232
import io.trino.sql.planner.iterative.rule.UnwrapSingleColumnRowInApply;
233+
import io.trino.sql.planner.iterative.rule.UnwrapTimestampToDateCastInComparison;
233234
import io.trino.sql.planner.optimizations.AddExchanges;
234235
import io.trino.sql.planner.optimizations.AddLocalExchanges;
235236
import io.trino.sql.planner.optimizations.BeginTableWrite;
@@ -356,6 +357,7 @@ public PlanOptimizers(
356357
.addAll(new SimplifyExpressions(plannerContext, typeAnalyzer).rules())
357358
.addAll(new UnwrapRowSubscript().rules())
358359
.addAll(new PushCastIntoRow().rules())
360+
.addAll(new UnwrapTimestampToDateCastInComparison(plannerContext, typeAnalyzer).rules())
359361
.addAll(new UnwrapCastInComparison(plannerContext, typeAnalyzer).rules())
360362
.addAll(new RemoveDuplicateConditions(metadata).rules())
361363
.addAll(new CanonicalizeExpressions(plannerContext, typeAnalyzer).rules())
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.sql.planner.iterative.rule;
15+
16+
import io.trino.Session;
17+
import io.trino.metadata.OperatorNotFoundException;
18+
import io.trino.metadata.ResolvedFunction;
19+
import io.trino.spi.TrinoException;
20+
import io.trino.spi.type.DateType;
21+
import io.trino.spi.type.TimestampType;
22+
import io.trino.spi.type.Type;
23+
import io.trino.sql.InterpretedFunctionInvoker;
24+
import io.trino.sql.PlannerContext;
25+
import io.trino.sql.planner.ExpressionInterpreter;
26+
import io.trino.sql.planner.LiteralEncoder;
27+
import io.trino.sql.planner.NoOpSymbolResolver;
28+
import io.trino.sql.planner.TypeAnalyzer;
29+
import io.trino.sql.planner.TypeProvider;
30+
import io.trino.sql.tree.Cast;
31+
import io.trino.sql.tree.ComparisonExpression;
32+
import io.trino.sql.tree.Expression;
33+
import io.trino.sql.tree.ExpressionTreeRewriter;
34+
import io.trino.sql.tree.IsNullPredicate;
35+
import io.trino.sql.tree.NullLiteral;
36+
37+
import java.util.Optional;
38+
39+
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
40+
import static io.trino.spi.type.DateType.DATE;
41+
import static io.trino.sql.ExpressionUtils.and;
42+
import static io.trino.sql.ExpressionUtils.or;
43+
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
44+
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN;
45+
import static java.util.Objects.requireNonNull;
46+
47+
/**
48+
* Rewrites CAST(ts_column as DATE) OP date_literal to range expression on ts_column. Dropping cast
49+
* allows for further optimizations, such as pushdown into connectors.
50+
* <p>
51+
* TODO: replace with more general mechanism supporting broader range of types
52+
*
53+
* @see io.trino.sql.planner.iterative.rule.UnwrapCastInComparison
54+
*/
55+
public class UnwrapTimestampToDateCastInComparison
56+
extends ExpressionRewriteRuleSet
57+
{
58+
public UnwrapTimestampToDateCastInComparison(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer)
59+
{
60+
super(createRewrite(plannerContext, typeAnalyzer));
61+
}
62+
63+
private static ExpressionRewriter createRewrite(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer)
64+
{
65+
requireNonNull(plannerContext, "plannerContext is null");
66+
requireNonNull(typeAnalyzer, "typeAnalyzer is null");
67+
68+
return (expression, context) -> unwrapCasts(context.getSession(), plannerContext, typeAnalyzer, context.getSymbolAllocator().getTypes(), expression);
69+
}
70+
71+
public static Expression unwrapCasts(Session session,
72+
PlannerContext plannerContext,
73+
TypeAnalyzer typeAnalyzer,
74+
TypeProvider types,
75+
Expression expression)
76+
{
77+
return ExpressionTreeRewriter.rewriteWith(new Visitor(plannerContext, typeAnalyzer, session, types), expression);
78+
}
79+
80+
private static class Visitor
81+
extends io.trino.sql.tree.ExpressionRewriter<Void>
82+
{
83+
private final PlannerContext plannerContext;
84+
private final TypeAnalyzer typeAnalyzer;
85+
private final Session session;
86+
private final TypeProvider types;
87+
private final InterpretedFunctionInvoker functionInvoker;
88+
private final LiteralEncoder literalEncoder;
89+
90+
public Visitor(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types)
91+
{
92+
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
93+
this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null");
94+
this.session = requireNonNull(session, "session is null");
95+
this.types = requireNonNull(types, "types is null");
96+
this.functionInvoker = new InterpretedFunctionInvoker(plannerContext.getMetadata());
97+
this.literalEncoder = new LiteralEncoder(plannerContext);
98+
}
99+
100+
@Override
101+
public Expression rewriteComparisonExpression(ComparisonExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
102+
{
103+
ComparisonExpression expression = (ComparisonExpression) treeRewriter.defaultRewrite((Expression) node, null);
104+
return unwrapCast(expression);
105+
}
106+
107+
private Expression unwrapCast(ComparisonExpression expression)
108+
{
109+
// Canonicalization is handled by CanonicalizeExpressionRewriter
110+
if (!(expression.getLeft() instanceof Cast)) {
111+
return expression;
112+
}
113+
114+
Object right = new ExpressionInterpreter(expression.getRight(), plannerContext, session, typeAnalyzer.getTypes(session, types, expression.getRight()))
115+
.optimize(NoOpSymbolResolver.INSTANCE);
116+
117+
Cast cast = (Cast) expression.getLeft();
118+
ComparisonExpression.Operator operator = expression.getOperator();
119+
120+
if (right == null || right instanceof NullLiteral) {
121+
// handled by general UnwrapCastInComparison
122+
return expression;
123+
}
124+
125+
if (right instanceof Expression) {
126+
return expression;
127+
}
128+
129+
Type sourceType = typeAnalyzer.getType(session, types, cast.getExpression());
130+
Type targetType = typeAnalyzer.getType(session, types, expression.getRight());
131+
132+
if (sourceType instanceof TimestampType && targetType == DATE) {
133+
return unwrapTimestampToDateCast(session, (TimestampType) sourceType, (DateType) targetType, operator, cast.getExpression(), (long) right).orElse(expression);
134+
}
135+
136+
return expression;
137+
}
138+
139+
private Optional<Expression> unwrapTimestampToDateCast(Session session, TimestampType sourceType, DateType targetType, ComparisonExpression.Operator operator, Expression timestampExpression, long date)
140+
{
141+
ResolvedFunction targetToSource;
142+
try {
143+
targetToSource = plannerContext.getMetadata().getCoercion(session, targetType, sourceType);
144+
}
145+
catch (OperatorNotFoundException e) {
146+
throw new TrinoException(GENERIC_INTERNAL_ERROR, e);
147+
}
148+
149+
Expression dateTimestamp = literalEncoder.toExpression(session, coerce(date, targetToSource), sourceType);
150+
Expression nextDateTimestamp = literalEncoder.toExpression(session, coerce(date + 1, targetToSource), sourceType);
151+
152+
switch (operator) {
153+
case EQUAL:
154+
return Optional.of(
155+
and(
156+
new ComparisonExpression(GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp),
157+
new ComparisonExpression(LESS_THAN, timestampExpression, nextDateTimestamp)));
158+
case NOT_EQUAL:
159+
return Optional.of(
160+
or(
161+
new ComparisonExpression(LESS_THAN, timestampExpression, dateTimestamp),
162+
new ComparisonExpression(GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp)));
163+
case LESS_THAN:
164+
return Optional.of(new ComparisonExpression(LESS_THAN, timestampExpression, dateTimestamp));
165+
case LESS_THAN_OR_EQUAL:
166+
return Optional.of(new ComparisonExpression(LESS_THAN, timestampExpression, nextDateTimestamp));
167+
case GREATER_THAN:
168+
return Optional.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp));
169+
case GREATER_THAN_OR_EQUAL:
170+
return Optional.of(new ComparisonExpression(GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp));
171+
case IS_DISTINCT_FROM:
172+
return Optional.of(
173+
or(
174+
new IsNullPredicate(timestampExpression),
175+
new ComparisonExpression(LESS_THAN, timestampExpression, dateTimestamp),
176+
new ComparisonExpression(GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp)));
177+
}
178+
throw new TrinoException(GENERIC_INTERNAL_ERROR, "Unsupported operator: " + operator);
179+
}
180+
181+
private Object coerce(Object value, ResolvedFunction coercion)
182+
{
183+
return functionInvoker.invoke(coercion, session.toConnectorSession(), value);
184+
}
185+
}
186+
}

0 commit comments

Comments
 (0)