Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
import org.apache.calcite.sql.SqlNumericLiteral;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionRegistry;
import org.apache.pinot.common.function.FunctionUtils;
import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.common.request.DataSource;
import org.apache.pinot.common.request.Expression;
Expand All @@ -51,7 +54,10 @@
import org.apache.pinot.common.request.Identifier;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.request.PinotQuery;
import org.apache.pinot.common.request.context.LiteralContext;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.utils.BigDecimalUtils;
import org.apache.pinot.spi.utils.BytesUtils;
import org.apache.pinot.spi.utils.CommonConstants.Broker.Request;
Expand Down Expand Up @@ -664,4 +670,49 @@ public static void applyTimestampIndexOverrideHints(
}
}
}

/**
* Infers the type of an expression by recursively traversing its expression tree. Identifier types are resolved
* from the provided schema, and literal types are taken from the literal itself. For function calls, only scalar
* functions are currently supported: aggregation functions, and transform functions that have no equivalent scalar
* function, resolve to {@link ColumnDataType#UNKNOWN}. Returns {@code null} only when the given expression is null.
*/
@Nullable
public static ColumnDataType inferExpressionType(@Nullable Expression expression, Schema schema) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want to accept a null expression here. If we don't accept a null expression, the return value is also never null.

if (expression == null) {
return null;
}

if (expression.isSetIdentifier()) {
String columnName = expression.getIdentifier().getName();
FieldSpec fieldSpec = schema.getFieldSpecFor(columnName);
if (fieldSpec == null) {
return ColumnDataType.UNKNOWN;
}
return ColumnDataType.fromDataType(fieldSpec.getDataType(), fieldSpec.isSingleValueField());
}

if (expression.isSetLiteral()) {
LiteralContext literalContext = new LiteralContext(expression.getLiteral());
return ColumnDataType.fromDataType(literalContext.getType(), literalContext.isSingleValue());
}

if (expression.isSetFunctionCall()) {
Function fn = expression.getFunctionCall();
int numOperands = fn.getOperandsSize();
ColumnDataType[] argTypes = new ColumnDataType[numOperands];
for (int i = 0; i < numOperands; i++) {
ColumnDataType argType = inferExpressionType(fn.getOperands().get(i), schema);
argTypes[i] = argType != null ? argType : ColumnDataType.UNKNOWN;
}
FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(fn.getOperator(), argTypes);
if (functionInfo != null) {
Class<?> returnClass = functionInfo.getMethod().getReturnType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably not enough, especially when a function returns Object, because the same function can return values of various types depending on its arguments, e.g. JSON_EXTRACT_SCALAR.
Ideally we want to use SqlReturnTypeInference to infer the return type of a function. This info should be available from PinotOperatorTable. All the working functions (at least in MSE) are already registered there.

ColumnDataType returnType = FunctionUtils.getColumnDataType(returnClass);
return returnType != null ? returnType : ColumnDataType.UNKNOWN;
}
return ColumnDataType.UNKNOWN;
}
return ColumnDataType.UNKNOWN;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@
*/
package org.apache.pinot.common.utils.request;

import java.math.BigDecimal;
import java.util.List;
import java.util.Set;
import org.apache.calcite.sql.SqlDialect;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.pinot.common.function.FunctionRegistry;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.ExpressionType;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.sql.parsers.CalciteSqlParser;
import org.apache.pinot.sql.parsers.PinotSqlType;
import org.apache.pinot.sql.parsers.SqlNodeAndOptions;
Expand Down Expand Up @@ -83,4 +89,168 @@ public void testResolveTableNames(String query, Set<String> expectedSet) {
assertEquals(tableNames, expectedSet);
}
}

/**
 * Builds the schema shared by the expression type inference tests. It declares one single-value dimension for each
 * of INT, LONG, FLOAT, DOUBLE, BIG_DECIMAL, STRING, BOOLEAN, TIMESTAMP and BYTES (named with an "sv" prefix),
 * followed by one multi-value dimension for each of the same types except BIG_DECIMAL (named with an "mv" prefix).
 */
private static Schema buildTestSchema() {
  String[] singleValueNames = {
      "svInt", "svLong", "svFloat", "svDouble", "svBigDecimal", "svString", "svBoolean", "svTimestamp", "svBytes"
  };
  FieldSpec.DataType[] singleValueTypes = {
      FieldSpec.DataType.INT, FieldSpec.DataType.LONG, FieldSpec.DataType.FLOAT, FieldSpec.DataType.DOUBLE,
      FieldSpec.DataType.BIG_DECIMAL, FieldSpec.DataType.STRING, FieldSpec.DataType.BOOLEAN,
      FieldSpec.DataType.TIMESTAMP, FieldSpec.DataType.BYTES
  };
  String[] multiValueNames = {
      "mvInt", "mvLong", "mvFloat", "mvDouble", "mvString", "mvBoolean", "mvTimestamp", "mvBytes"
  };
  FieldSpec.DataType[] multiValueTypes = {
      FieldSpec.DataType.INT, FieldSpec.DataType.LONG, FieldSpec.DataType.FLOAT, FieldSpec.DataType.DOUBLE,
      FieldSpec.DataType.STRING, FieldSpec.DataType.BOOLEAN, FieldSpec.DataType.TIMESTAMP,
      FieldSpec.DataType.BYTES
  };

  Schema.SchemaBuilder builder = new Schema.SchemaBuilder();
  // Fields are registered in the same order as before: all single-value columns first, then the multi-value ones.
  for (int i = 0; i < singleValueNames.length; i++) {
    builder = builder.addSingleValueDimension(singleValueNames[i], singleValueTypes[i]);
  }
  for (int i = 0; i < multiValueNames.length; i++) {
    builder = builder.addMultiValueDimension(multiValueNames[i], multiValueTypes[i]);
  }
  return builder.build();
}

/**
 * Verifies that identifier expressions resolve to the column type declared in the schema, for both single-value
 * and multi-value columns, and that a column missing from the schema resolves to {@link ColumnDataType#UNKNOWN}.
 */
@Test
public void testInferExpressionTypeIdentifiers() {
  Schema testSchema = buildTestSchema();

  // Column name -> expected inferred type, checked in declaration order.
  String[] columnNames = {
      "svInt", "svLong", "svFloat", "svDouble", "svBigDecimal", "svString", "svBoolean", "svTimestamp", "svBytes",
      "mvInt", "mvLong", "mvFloat", "mvDouble", "mvString", "mvBoolean", "mvTimestamp", "mvBytes"
  };
  ColumnDataType[] expectedTypes = {
      ColumnDataType.INT, ColumnDataType.LONG, ColumnDataType.FLOAT, ColumnDataType.DOUBLE,
      ColumnDataType.BIG_DECIMAL, ColumnDataType.STRING, ColumnDataType.BOOLEAN, ColumnDataType.TIMESTAMP,
      ColumnDataType.BYTES,
      ColumnDataType.INT_ARRAY, ColumnDataType.LONG_ARRAY, ColumnDataType.FLOAT_ARRAY,
      ColumnDataType.DOUBLE_ARRAY, ColumnDataType.STRING_ARRAY, ColumnDataType.BOOLEAN_ARRAY,
      ColumnDataType.TIMESTAMP_ARRAY, ColumnDataType.BYTES_ARRAY
  };

  for (int i = 0; i < columnNames.length; i++) {
    Expression column = RequestUtils.getIdentifierExpression(columnNames[i]);
    assertEquals(RequestUtils.inferExpressionType(column, testSchema), expectedTypes[i]);
  }

  // A column that is not part of the schema cannot be typed.
  Expression missingColumn = RequestUtils.getIdentifierExpression("unknownCol");
  assertEquals(RequestUtils.inferExpressionType(missingColumn, testSchema), ColumnDataType.UNKNOWN);
}

/**
 * Verifies that literal expressions resolve to the type of the literal value: scalar literals map to the
 * single-value types and array literals map to the corresponding array types.
 */
@Test
public void testInferExpressionTypeLiterals() {
  Schema testSchema = buildTestSchema();
  Expression literal;

  // Scalar literals
  literal = RequestUtils.getLiteralExpression(true);
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.BOOLEAN);

  literal = RequestUtils.getLiteralExpression(1);
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.INT);

  literal = RequestUtils.getLiteralExpression(1L);
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.LONG);

  literal = RequestUtils.getLiteralExpression(1.0f);
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.FLOAT);

  literal = RequestUtils.getLiteralExpression(1.0d);
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.DOUBLE);

  literal = RequestUtils.getLiteralExpression(new BigDecimal("123.45"));
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.BIG_DECIMAL);

  literal = RequestUtils.getLiteralExpression("abc");
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.STRING);

  literal = RequestUtils.getLiteralExpression(new byte[]{1, 2});
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.BYTES);

  // Array literals
  literal = RequestUtils.getLiteralExpression(new int[]{1, 2});
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.INT_ARRAY);

  literal = RequestUtils.getLiteralExpression(new long[]{1L, 2L});
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.LONG_ARRAY);

  literal = RequestUtils.getLiteralExpression(new float[]{1.0f, 2.0f});
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.FLOAT_ARRAY);

  literal = RequestUtils.getLiteralExpression(new double[]{1.0d, 2.0d});
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.DOUBLE_ARRAY);

  literal = RequestUtils.getLiteralExpression(new String[]{"a", "b"});
  assertEquals(RequestUtils.inferExpressionType(literal, testSchema), ColumnDataType.STRING_ARRAY);
}

/**
 * Verifies that a scalar function call resolves to the return type of the matching registered scalar function.
 */
@Test
public void testInferExpressionTypeScalarFunction() {
  Schema testSchema = buildTestSchema();
  FunctionRegistry.init();

  // abs(svInt) is expected to resolve to DOUBLE
  String absName = FunctionRegistry.canonicalize("abs");
  Expression intColumn = RequestUtils.getIdentifierExpression("svInt");
  Expression absCall = RequestUtils.getFunctionExpression(absName, intColumn);
  ColumnDataType absType = RequestUtils.inferExpressionType(absCall, testSchema);
  assertEquals(absType, ColumnDataType.DOUBLE);

  // intDiv(10.0, 3.0) is expected to resolve to LONG
  String intDivName = FunctionRegistry.canonicalize("intDiv");
  Expression dividend = RequestUtils.getLiteralExpression(10.0d);
  Expression divisor = RequestUtils.getLiteralExpression(3.0d);
  Expression intDivCall = RequestUtils.getFunctionExpression(intDivName, dividend, divisor);
  ColumnDataType intDivType = RequestUtils.inferExpressionType(intDivCall, testSchema);
  assertEquals(intDivType, ColumnDataType.LONG);
}

/**
 * Verifies that the inferred type of a polymorphic scalar function ({@code plus}) depends on the types of its
 * arguments.
 */
@Test
public void testInferExpressionTypeScalarFunctionPolymorphic() {
  Schema testSchema = buildTestSchema();
  FunctionRegistry.init();

  // plus(svLong, svLong) is expected to resolve to LONG
  String plusForLongs = FunctionRegistry.canonicalize("plus");
  Expression leftLong = RequestUtils.getIdentifierExpression("svLong");
  Expression rightLong = RequestUtils.getIdentifierExpression("svLong");
  Expression longSum = RequestUtils.getFunctionExpression(plusForLongs, leftLong, rightLong);
  assertEquals(RequestUtils.inferExpressionType(longSum, testSchema), ColumnDataType.LONG);

  // plus(svLong, svDouble) is expected to resolve to DOUBLE
  String plusForMixed = FunctionRegistry.canonicalize("plus");
  Expression longOperand = RequestUtils.getIdentifierExpression("svLong");
  Expression doubleOperand = RequestUtils.getIdentifierExpression("svDouble");
  Expression mixedSum = RequestUtils.getFunctionExpression(plusForMixed, longOperand, doubleOperand);
  assertEquals(RequestUtils.inferExpressionType(mixedSum, testSchema), ColumnDataType.DOUBLE);
}

/**
 * Verifies that type inference recurses through nested function calls: the type inferred for the inner call is
 * used as the argument type when resolving the outer call.
 */
@Test
public void testInferExpressionTypeNestedFunctions() {
  Schema testSchema = buildTestSchema();
  FunctionRegistry.init();

  // Inner call: abs(3.0)
  String absName = FunctionRegistry.canonicalize("abs");
  Expression three = RequestUtils.getLiteralExpression(3.0d);
  Expression innerCall = RequestUtils.getFunctionExpression(absName, three);

  // Outer call: ln(abs(3.0)) is expected to resolve to DOUBLE
  String lnName = FunctionRegistry.canonicalize("ln");
  Expression outerCall = RequestUtils.getFunctionExpression(lnName, innerCall);

  ColumnDataType inferredType = RequestUtils.inferExpressionType(outerCall, testSchema);
  assertEquals(inferredType, ColumnDataType.DOUBLE);
}

/**
 * Verifies that function calls which cannot be resolved to a scalar function (an unregistered name, or
 * {@code min} with a single argument) are typed as {@link ColumnDataType#UNKNOWN}.
 */
@Test
public void testInferExpressionTypeUnknownAndAggregationFunctions() {
  Schema testSchema = buildTestSchema();
  FunctionRegistry.init();

  // A function name that is not registered resolves to UNKNOWN
  List<Expression> unknownFnArgs = List.of(RequestUtils.getIdentifierExpression("svInt"));
  Expression unknownFnCall = RequestUtils.getFunctionExpression("notafunction", unknownFnArgs);
  ColumnDataType unknownFnType = RequestUtils.inferExpressionType(unknownFnCall, testSchema);
  assertEquals(unknownFnType, ColumnDataType.UNKNOWN);

  // min with one argument (the aggregation form, or a wrong arg count for scalar min) resolves to UNKNOWN
  String minName = FunctionRegistry.canonicalize("min");
  List<Expression> minArgs = List.of(RequestUtils.getIdentifierExpression("svInt"));
  Expression minCall = RequestUtils.getFunctionExpression(minName, minArgs);
  ColumnDataType minType = RequestUtils.inferExpressionType(minCall, testSchema);
  assertEquals(minType, ColumnDataType.UNKNOWN);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.PinotQuery;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;

/**
Expand Down Expand Up @@ -79,19 +80,14 @@ private void maybeRewriteAggregateFunction(@Nullable Expression expression, Sche
|| functionName.equals(AggregationFunctionType.MAX.getName()))
&& function.getOperandsSize() == 1) {
Expression operand = function.getOperands().get(0);
// TODO: Handle more complex expressions (e.g. MIN(trim(stringCol)) )
if (operand.isSetIdentifier()) {
String columnName = operand.getIdentifier().getName();
if (schema != null) {
FieldSpec fieldSpec = schema.getFieldSpecFor(columnName);
if (fieldSpec != null && fieldSpec.getDataType().getStoredType() == FieldSpec.DataType.STRING) {
String newFunctionName =
functionName.equals(AggregationFunctionType.MIN.getName())
? AggregationFunctionType.MINSTRING.name().toLowerCase()
: AggregationFunctionType.MAXSTRING.name().toLowerCase();
function.setOperator(newFunctionName);
}
}

ColumnDataType dataType = RequestUtils.inferExpressionType(operand, schema);
if (dataType == ColumnDataType.STRING) {
String newFunctionName =
functionName.equals(AggregationFunctionType.MIN.getName())
? AggregationFunctionType.MINSTRING.name().toLowerCase()
: AggregationFunctionType.MAXSTRING.name().toLowerCase();
function.setOperator(newFunctionName);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,11 @@ private void testHardcodedQueriesCommon()
query = "SELECT MIN(OriginCityName), MAX(OriginCityName) FROM mytable";
testQuery(query);

// Verify that rewrite also works when input operand is a complex expression (with non-canonical function names)
query = "SELECT MIN(SUB_STRING(TRIM(OriginCityName), 1)), max(trim(OriginCityName)) FROM mytable";
h2Query = "SELECT MIN(SUBSTRING(TRIM(OriginCityName), 1)), MAX(TRIM(OriginCityName)) FROM mytable";
testQuery(query, h2Query);

// Test orderedPreferredPools option, which falls back to non-preferred pools
// when none of the preferred pools is available
query = "SELECT count(*) FROM mytable WHERE OriginState LIKE 'A_' option(orderedPreferredPools=0|1)";
Expand Down
Loading