diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/EmptyResponseUtils.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/EmptyResponseUtils.java index bd42c2e1c5fb..36cb4a70e413 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/EmptyResponseUtils.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/EmptyResponseUtils.java @@ -35,7 +35,6 @@ import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.request.context.QueryContext; import org.apache.pinot.query.QueryEnvironment; -import org.apache.pinot.query.planner.logical.RelToPlanNodeConverter; import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.data.Schema; @@ -165,7 +164,7 @@ public static void fillEmptyResponseSchema(boolean useMSE, BrokerResponse respon columnNames[i] = dataTypeField.getName(); ColumnDataType columnDataType; try { - columnDataType = RelToPlanNodeConverter.convertToColumnDataType(dataTypeField.getType()); + columnDataType = ColumnDataType.fromRelDataType(dataTypeField.getType()); } catch (Exception ignored) { columnDataType = ColumnDataType.UNKNOWN; } @@ -182,7 +181,7 @@ public static void fillEmptyResponseSchema(boolean useMSE, BrokerResponse respon // Fill data type with the validated row type when it is available. for (int i = 0; i < numColumns; i++) { try { - columnDataTypes[i] = RelToPlanNodeConverter.convertToColumnDataType(dataTypeFields.get(i).getType()); + columnDataTypes[i] = ColumnDataType.fromRelDataType(dataTypeFields.get(i).getType()); } catch (Exception ignored) { // Ignore exception and keep the type from response } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/catalog/PinotCatalogReader.java b/pinot-common/src/main/java/org/apache/pinot/common/calcite/catalog/PinotCatalogReader.java similarity index 97% rename from pinot-common/src/main/java/org/apache/pinot/common/catalog/PinotCatalogReader.java rename to pinot-common/src/main/java/org/apache/pinot/common/calcite/catalog/PinotCatalogReader.java index c3eac92fbe3e..26d86b39b0f5 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/catalog/PinotCatalogReader.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/calcite/catalog/PinotCatalogReader.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.common.catalog; +package org.apache.pinot.common.calcite.catalog; import com.google.common.collect.ImmutableList; import java.util.List; diff --git a/pinot-common/src/main/java/org/apache/pinot/common/catalog/PinotNameMatcher.java b/pinot-common/src/main/java/org/apache/pinot/common/calcite/catalog/PinotNameMatcher.java similarity index 98% rename from pinot-common/src/main/java/org/apache/pinot/common/catalog/PinotNameMatcher.java rename to pinot-common/src/main/java/org/apache/pinot/common/calcite/catalog/PinotNameMatcher.java index 42feb9da9b3f..c71a7cfa4be6 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/catalog/PinotNameMatcher.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/calcite/catalog/PinotNameMatcher.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.common.catalog; +package org.apache.pinot.common.calcite.catalog; import java.util.LinkedHashSet; import java.util.List; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java b/pinot-common/src/main/java/org/apache/pinot/common/calcite/function/PinotOperatorTable.java similarity index 98% rename from pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java rename to pinot-common/src/main/java/org/apache/pinot/common/calcite/function/PinotOperatorTable.java index ef3ea9c78f71..e9251304f14b 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/calcite/function/PinotOperatorTable.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.calcite.sql.fun; +package org.apache.pinot.common.calcite.function; import com.google.common.base.Preconditions; import com.google.common.base.Suppliers; @@ -54,7 +54,9 @@ /** - * This class defines all the {@link SqlOperator}s allowed by Pinot. + * This class defines all the Calcite {@link SqlOperator}s allowed by Pinot. This is primarily used by the multi-stage + * query engine during query parsing and validation. + * *
It contains the following types of operators: *
+ * See {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#decimalOf}.
+ *
+ * {@link RequestUtils#getLiteralExpression(SqlLiteral)}
+ * @param relDataType the DECIMAL rel data type.
+ * @param isArray
+ * @return converted {@link ColumnDataType}.
+ *
+ */
+ private static ColumnDataType resolveDecimal(RelDataType relDataType, boolean isArray) {
+ int precision = relDataType.getPrecision();
+ int scale = relDataType.getScale();
+ if (scale == 0) {
+ if (precision <= 10) {
+ return isArray ? ColumnDataType.INT_ARRAY : ColumnDataType.INT;
+ } else if (precision <= 38) {
+ return isArray ? ColumnDataType.LONG_ARRAY : ColumnDataType.LONG;
+ } else {
+ return isArray ? ColumnDataType.DOUBLE_ARRAY : ColumnDataType.BIG_DECIMAL;
+ }
+ } else {
+ // NOTE: Do not use FLOAT to represent DECIMAL to be consistent with single-stage engine behavior.
+ // See {@link RequestUtils#getLiteralExpression(SqlLiteral)}.
+ if (precision <= 30) {
+ return isArray ? ColumnDataType.DOUBLE_ARRAY : ColumnDataType.DOUBLE;
+ } else {
+ return isArray ? ColumnDataType.DOUBLE_ARRAY : ColumnDataType.BIG_DECIMAL;
+ }
+ }
+ }
+
public abstract RelDataType toType(RelDataTypeFactory typeFactory);
}
}
diff --git a/pinot-common/src/test/java/org/apache/pinot/common/calcite/function/SseExpressionTypeInferenceTest.java b/pinot-common/src/test/java/org/apache/pinot/common/calcite/function/SseExpressionTypeInferenceTest.java
new file mode 100644
index 000000000000..28a1f8096b8f
--- /dev/null
+++ b/pinot-common/src/test/java/org/apache/pinot/common/calcite/function/SseExpressionTypeInferenceTest.java
@@ -0,0 +1,297 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.calcite.function;
+
+import java.math.BigDecimal;
+import java.util.List;
+import org.apache.pinot.common.function.FunctionRegistry;
+import org.apache.pinot.common.request.Expression;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.common.utils.request.RequestUtils;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.data.Schema;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertThrows;
+
+
+public class SseExpressionTypeInferenceTest {
+
+ private static Schema buildTestSchema() {
+ return new Schema.SchemaBuilder()
+ .addSingleValueDimension("svInt", FieldSpec.DataType.INT)
+ .addSingleValueDimension("svLong", FieldSpec.DataType.LONG)
+ .addSingleValueDimension("svFloat", FieldSpec.DataType.FLOAT)
+ .addSingleValueDimension("svDouble", FieldSpec.DataType.DOUBLE)
+ .addSingleValueDimension("svBigDecimal", FieldSpec.DataType.BIG_DECIMAL)
+ .addSingleValueDimension("svString", FieldSpec.DataType.STRING)
+ .addSingleValueDimension("svBoolean", FieldSpec.DataType.BOOLEAN)
+ .addSingleValueDimension("svTimestamp", FieldSpec.DataType.TIMESTAMP)
+ .addSingleValueDimension("svBytes", FieldSpec.DataType.BYTES)
+ .addMultiValueDimension("mvInt", FieldSpec.DataType.INT)
+ .addMultiValueDimension("mvLong", FieldSpec.DataType.LONG)
+ .addMultiValueDimension("mvFloat", FieldSpec.DataType.FLOAT)
+ .addMultiValueDimension("mvDouble", FieldSpec.DataType.DOUBLE)
+ .addMultiValueDimension("mvString", FieldSpec.DataType.STRING)
+ .addMultiValueDimension("mvBoolean", FieldSpec.DataType.BOOLEAN)
+ .addMultiValueDimension("mvTimestamp", FieldSpec.DataType.TIMESTAMP)
+ .addMultiValueDimension("mvBytes", FieldSpec.DataType.BYTES)
+ .build();
+ }
+
+ @Test
+ public void testInferExpressionTypeIdentifiers() {
+ Schema schema = buildTestSchema();
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("svInt"), schema),
+ DataSchema.ColumnDataType.INT);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("svLong"), schema),
+ DataSchema.ColumnDataType.LONG);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("svFloat"), schema),
+ DataSchema.ColumnDataType.FLOAT);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("svDouble"), schema),
+ DataSchema.ColumnDataType.DOUBLE);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("svBigDecimal"), schema),
+ DataSchema.ColumnDataType.BIG_DECIMAL);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("svString"), schema),
+ DataSchema.ColumnDataType.STRING);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("svBoolean"), schema),
+ DataSchema.ColumnDataType.BOOLEAN);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("svTimestamp"), schema),
+ DataSchema.ColumnDataType.TIMESTAMP);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("svBytes"), schema),
+ DataSchema.ColumnDataType.BYTES);
+
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("mvInt"), schema),
+ DataSchema.ColumnDataType.INT_ARRAY);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("mvLong"), schema),
+ DataSchema.ColumnDataType.LONG_ARRAY);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("mvFloat"), schema),
+ DataSchema.ColumnDataType.FLOAT_ARRAY);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("mvDouble"), schema),
+ DataSchema.ColumnDataType.DOUBLE_ARRAY);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("mvString"), schema),
+ DataSchema.ColumnDataType.STRING_ARRAY);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("mvBoolean"), schema),
+ DataSchema.ColumnDataType.BOOLEAN_ARRAY);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("mvTimestamp"), schema),
+ DataSchema.ColumnDataType.TIMESTAMP_ARRAY);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("mvBytes"), schema),
+ DataSchema.ColumnDataType.BYTES_ARRAY);
+
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getIdentifierExpression("unknownCol"), schema),
+ DataSchema.ColumnDataType.UNKNOWN);
+ }
+
+ @Test
+ public void testInferExpressionTypeLiterals() {
+ Schema schema = buildTestSchema();
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(true), schema),
+ DataSchema.ColumnDataType.BOOLEAN);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(1), schema),
+ DataSchema.ColumnDataType.INT);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(1L), schema),
+ DataSchema.ColumnDataType.LONG);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(1.0f), schema),
+ DataSchema.ColumnDataType.FLOAT);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(1.0d), schema),
+ DataSchema.ColumnDataType.DOUBLE);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(new BigDecimal("123.45")),
+ schema), DataSchema.ColumnDataType.BIG_DECIMAL);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression("abc"), schema),
+ DataSchema.ColumnDataType.STRING);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(new byte[]{1, 2}), schema),
+ DataSchema.ColumnDataType.BYTES);
+
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(new int[]{1, 2}), schema),
+ DataSchema.ColumnDataType.INT_ARRAY);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(new long[]{1L, 2L}), schema),
+ DataSchema.ColumnDataType.LONG_ARRAY);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(new float[]{1.0f, 2.0f}),
+ schema), DataSchema.ColumnDataType.FLOAT_ARRAY);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(new double[]{1.0d, 2.0d}),
+ schema), DataSchema.ColumnDataType.DOUBLE_ARRAY);
+ assertEquals(
+ SseExpressionTypeInference.inferReturnRelType(RequestUtils.getLiteralExpression(new String[]{"a", "b"}),
+ schema), DataSchema.ColumnDataType.STRING_ARRAY);
+ }
+
+ @Test
+ public void testInferExpressionTypeScalarFunction() {
+ Schema schema = buildTestSchema();
+ FunctionRegistry.init();
+
+ // abs(doubleCol) -> DOUBLE
+ Expression absOnInt = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("abs"), RequestUtils.getIdentifierExpression("svDouble"));
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(absOnInt, schema), DataSchema.ColumnDataType.DOUBLE);
+
+ // intDiv(10.0, 3.0) -> LONG
+ Expression intDiv = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("intDiv"),
+ RequestUtils.getLiteralExpression(10.0d),
+ RequestUtils.getLiteralExpression(3.0d));
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(intDiv, schema), DataSchema.ColumnDataType.LONG);
+
+ // jsonExtractScalar(stringCol, '$.name', 'INT') -> INT
+ Expression jsonExtractScalar = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("jsonExtractScalar"),
+ RequestUtils.getIdentifierExpression("svString"),
+ RequestUtils.getLiteralExpression("$.name"),
+ RequestUtils.getLiteralExpression("INT"));
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(jsonExtractScalar, schema),
+ DataSchema.ColumnDataType.INT);
+
+ // jsonExtractScalar(stringCol, '$.name', 'DOUBLE', '3') -> DOUBLE
+ Expression jsonExtractScalarWithDefault = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("jsonExtractScalar"),
+ RequestUtils.getIdentifierExpression("svString"),
+ RequestUtils.getLiteralExpression("$.name"),
+ RequestUtils.getLiteralExpression("DOUBLE"),
+ RequestUtils.getLiteralExpression("3"));
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(jsonExtractScalarWithDefault, schema),
+ DataSchema.ColumnDataType.DOUBLE);
+ }
+
+ @Test
+ public void testInferExpressionTypeScalarFunctionPolymorphic() {
+ Schema schema = buildTestSchema();
+ FunctionRegistry.init();
+
+ // plus(long, long) -> LONG
+ Expression plusLongLong = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("plus"),
+ RequestUtils.getIdentifierExpression("svLong"),
+ RequestUtils.getIdentifierExpression("svLong"));
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(plusLongLong, schema), DataSchema.ColumnDataType.LONG);
+
+ // plus(long, double) -> DOUBLE
+ Expression plusLongDouble = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("plus"),
+ RequestUtils.getIdentifierExpression("svLong"),
+ RequestUtils.getIdentifierExpression("svDouble"));
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(plusLongDouble, schema),
+ DataSchema.ColumnDataType.DOUBLE);
+ }
+
+ @Test
+ public void testInferExpressionTypeNestedFunctions() {
+ Schema schema = buildTestSchema();
+ FunctionRegistry.init();
+
+ // ln(abs(3.0)) -> DOUBLE
+ Expression absExpression = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("abs"), RequestUtils.getLiteralExpression(3.0d));
+ Expression lnAbs = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("ln"), absExpression);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(lnAbs, schema), DataSchema.ColumnDataType.DOUBLE);
+
+ // reverse(lower(CAST(intCol AS STRING))) -> STRING
+ Expression castToString = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("cast"),
+ RequestUtils.getIdentifierExpression("svInt"),
+ RequestUtils.getLiteralExpression("INT"));
+ Expression trimExpression = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("lower"), castToString);
+ Expression reverseExpression = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("reverse"), trimExpression);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(reverseExpression, schema),
+ DataSchema.ColumnDataType.STRING);
+
+ // cast(plus(intCol, 10) AS DOUBLE) -> DOUBLE
+ Expression plusExpression = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("plus"),
+ RequestUtils.getIdentifierExpression("svInt"),
+ RequestUtils.getLiteralExpression(10));
+ Expression castAsDouble = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("cast"),
+ plusExpression,
+ RequestUtils.getLiteralExpression("DOUBLE"));
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(plusExpression, schema), DataSchema.ColumnDataType.INT);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(castAsDouble, schema),
+ DataSchema.ColumnDataType.DOUBLE);
+ }
+
+ @Test
+ public void testInferExpressionTypeUnknownFunction() {
+ Schema schema = buildTestSchema();
+ FunctionRegistry.init();
+
+ // Unknown function
+ Expression unknownFn = RequestUtils.getFunctionExpression("notafunction",
+ List.of(RequestUtils.getIdentifierExpression("svInt")));
+ assertThrows(IllegalArgumentException.class,
+ () -> SseExpressionTypeInference.inferReturnRelType(unknownFn, schema));
+ }
+
+ @Test
+ public void testInferExpressionTypeAggregationFunction() {
+ Schema schema = buildTestSchema();
+ FunctionRegistry.init();
+
+ // sum(intCol) -> LONG
+ Expression sumExpression = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("sum"),
+ RequestUtils.getIdentifierExpression("svInt"));
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(sumExpression, schema), DataSchema.ColumnDataType.LONG);
+
+ // min(doubleCol) -> DOUBLE
+ Expression minExpression = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("min"),
+ RequestUtils.getIdentifierExpression("svDouble"));
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(minExpression, schema),
+ DataSchema.ColumnDataType.DOUBLE);
+
+ // max(reverse(stringCol)) -> STRING
+ Expression reverseExpression = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("reverse"),
+ RequestUtils.getIdentifierExpression("svString"));
+ Expression maxExpression = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("max"),
+ reverseExpression);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(maxExpression, schema),
+ DataSchema.ColumnDataType.STRING);
+
+ // reverse(max(stringCol)) -> STRING
+ Expression maxStringExpression = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("max"),
+ RequestUtils.getIdentifierExpression("svString"));
+ Expression reverseMaxExpression = RequestUtils.getFunctionExpression(
+ FunctionRegistry.canonicalize("reverse"),
+ maxStringExpression);
+ assertEquals(SseExpressionTypeInference.inferReturnRelType(reverseMaxExpression, schema),
+ DataSchema.ColumnDataType.STRING);
+ }
+}
diff --git a/pinot-common/src/test/java/org/apache/pinot/common/utils/DataSchemaTest.java b/pinot-common/src/test/java/org/apache/pinot/common/utils/DataSchemaTest.java
index 3a22b1d30cc1..f49ab925a452 100644
--- a/pinot-common/src/test/java/org/apache/pinot/common/utils/DataSchemaTest.java
+++ b/pinot-common/src/test/java/org/apache/pinot/common/utils/DataSchemaTest.java
@@ -21,6 +21,12 @@
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.sql.Timestamp;
+import org.apache.calcite.rel.type.RelDataTypeSystem;
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.type.ArraySqlType;
+import org.apache.calcite.sql.type.BasicSqlType;
+import org.apache.calcite.sql.type.ObjectSqlType;
+import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.utils.BytesUtils;
import org.testng.Assert;
@@ -189,4 +195,104 @@ public void testColumnDataType() {
byte[] bytesValue = {12, 34, 56};
Assert.assertEquals(BYTES.format(bytesValue), BytesUtils.toHexString(bytesValue));
}
+
+ @Test
+ public void testConvertFromRelDataTypeToColumnDataTypeForObjectTypes() {
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.BOOLEAN, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.BOOLEAN);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.TINYINT, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.INT);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.SMALLINT, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.INT);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.INTEGER, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.INT);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.BIGINT, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.LONG);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.FLOAT, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.FLOAT);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.DOUBLE, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.DOUBLE);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.TIMESTAMP, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.TIMESTAMP);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.CHAR, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.STRING);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.VARCHAR, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.STRING);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.VARBINARY, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.BYTES);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ObjectSqlType(SqlTypeName.OTHER, SqlIdentifier.STAR, true, null, null)),
+ DataSchema.ColumnDataType.OBJECT);
+ }
+
+ @Test
+ public void testBigDecimal() {
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 10)),
+ DataSchema.ColumnDataType.INT);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 38)),
+ DataSchema.ColumnDataType.LONG);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 39)),
+ DataSchema.ColumnDataType.BIG_DECIMAL);
+
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 14, 10)),
+ DataSchema.ColumnDataType.DOUBLE);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 30, 10)),
+ DataSchema.ColumnDataType.DOUBLE);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 31, 10)),
+ DataSchema.ColumnDataType.BIG_DECIMAL);
+ }
+
+ @Test
+ public void testConvertToColumnDataTypeForArray() {
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.BOOLEAN, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.BOOLEAN_ARRAY);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.TINYINT, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.INT_ARRAY);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.SMALLINT, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.INT_ARRAY);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.INTEGER, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.INT_ARRAY);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.BIGINT, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.LONG_ARRAY);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.FLOAT, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.FLOAT_ARRAY);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.DOUBLE, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.DOUBLE_ARRAY);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.TIMESTAMP, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.TIMESTAMP_ARRAY);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.CHAR, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.STRING_ARRAY);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.VARCHAR, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.STRING_ARRAY);
+ Assert.assertEquals(DataSchema.ColumnDataType.fromRelDataType(
+ new ArraySqlType(new ObjectSqlType(SqlTypeName.VARBINARY, SqlIdentifier.STAR, true, null, null), true)),
+ DataSchema.ColumnDataType.BYTES_ARRAY);
+ }
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/statement/AggregateFunctionRewriteOptimizer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/statement/AggregateFunctionRewriteOptimizer.java
index e023c820135b..a3adf0735c12 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/statement/AggregateFunctionRewriteOptimizer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/statement/AggregateFunctionRewriteOptimizer.java
@@ -20,18 +20,24 @@
import java.util.List;
import javax.annotation.Nullable;
+import org.apache.pinot.common.calcite.function.SseExpressionTypeInference;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.PinotQuery;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.segment.spi.AggregationFunctionType;
-import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
/**
* Rewrites aggregate functions to type-specific versions in order to support polymorphic functions.
*/
public class AggregateFunctionRewriteOptimizer implements StatementOptimizer {
+ public static final Logger LOGGER = LoggerFactory.getLogger(AggregateFunctionRewriteOptimizer.class);
+
@Override
public void optimize(PinotQuery pinotQuery, @Nullable Schema schema) {
if (schema == null) {
@@ -74,24 +80,26 @@ private void maybeRewriteAggregateFunction(@Nullable Expression expression, Sche
return;
}
- // Rewrite MIN(stringCol) and MAX(stringCol) to MINSTRING / MAXSTRING
+ // Rewrite MIN(stringVal) and MAX(stringVal) to MINSTRING / MAXSTRING
if ((functionName.equals(AggregationFunctionType.MIN.getName())
|| functionName.equals(AggregationFunctionType.MAX.getName()))
&& function.getOperandsSize() == 1) {
Expression operand = function.getOperands().get(0);
- // TODO: Handle more complex expressions (e.g. MIN(trim(stringCol)) )
- if (operand.isSetIdentifier()) {
- String columnName = operand.getIdentifier().getName();
- if (schema != null) {
- FieldSpec fieldSpec = schema.getFieldSpecFor(columnName);
- if (fieldSpec != null && fieldSpec.getDataType().getStoredType() == FieldSpec.DataType.STRING) {
- String newFunctionName =
- functionName.equals(AggregationFunctionType.MIN.getName())
- ? AggregationFunctionType.MINSTRING.name().toLowerCase()
- : AggregationFunctionType.MAXSTRING.name().toLowerCase();
- function.setOperator(newFunctionName);
- }
- }
+
+ ColumnDataType dataType;
+ try {
+ dataType = SseExpressionTypeInference.inferReturnRelType(operand, schema);
+ } catch (Exception e) {
+ // Ignore exceptions during type inference and do not rewrite the function
+ LOGGER.warn("Exception while inferring return type for expression: {}", operand, e);
+ return;
+ }
+ if (dataType.getStoredType() == ColumnDataType.STRING) {
+ String newFunctionName =
+ functionName.equals(AggregationFunctionType.MIN.getName())
+ ? AggregationFunctionType.MINSTRING.name().toLowerCase()
+ : AggregationFunctionType.MAXSTRING.name().toLowerCase();
+ function.setOperator(newFunctionName);
}
}
}
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
index 0a774238f838..4b830806ded3 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
@@ -330,6 +330,11 @@ private void testHardcodedQueriesCommon()
query = "SELECT MIN(OriginCityName), MAX(OriginCityName) FROM mytable";
testQuery(query);
+ // Verify that rewrite also works when input operand is a complex expression (with non-canonical function names)
+ query = "SELECT MIN(SUB_STRING(TRIM(OriginCityName), 1)), max(trim(OriginCityName)) FROM mytable";
+ h2Query = "SELECT MIN(SUBSTRING(TRIM(OriginCityName), 1)), MAX(TRIM(OriginCityName)) FROM mytable";
+ testQuery(query, h2Query);
+
// Test orderedPreferredPools option which will fallbacks to non preferred Pools
// when non of preferred Pools is available
query = "SELECT count(*) FROM mytable WHERE OriginState LIKE 'A_' option(orderedPreferredPools=0|1)";
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java
index fea69ab5d211..62295c008af1 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java
@@ -48,7 +48,6 @@
import org.apache.pinot.common.function.FunctionRegistry;
import org.apache.pinot.common.function.QueryFunctionInvoker;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
-import org.apache.pinot.query.planner.logical.RelToPlanNodeConverter;
import org.apache.pinot.spi.utils.TimestampUtils;
import org.apache.pinot.sql.parsers.SqlCompilationException;
@@ -174,7 +173,7 @@ private static RexNode evaluateLiteralOnlyFunction(RexCall rexCall, RexBuilder r
// Function operands cannot be evaluated, skip
return rexCall;
}
- argumentTypes[i] = RelToPlanNodeConverter.convertToColumnDataType(rexLiteral.getType());
+ argumentTypes[i] = ColumnDataType.fromRelDataType(rexLiteral.getType());
arguments[i] = getLiteralValue(rexLiteral);
}
@@ -183,7 +182,7 @@ private static RexNode evaluateLiteralOnlyFunction(RexCall rexCall, RexBuilder r
// to is determined by the operator's return type. Pinot's CAST function implementation requires two arguments:
// the value to be cast and the target type.
argumentTypes = new ColumnDataType[]{argumentTypes[0], ColumnDataType.STRING};
- arguments = new Object[]{arguments[0], RelToPlanNodeConverter.convertToColumnDataType(rexCall.getType()).name()};
+ arguments = new Object[]{arguments[0], ColumnDataType.fromRelDataType(rexCall.getType()).name()};
}
String canonicalName = FunctionRegistry.canonicalize(PinotRuleUtils.extractFunctionName(rexCall));
FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(canonicalName, argumentTypes);
@@ -253,7 +252,7 @@ private static RexNode evaluateLiteralOnlyFunction(RexCall rexCall, RexBuilder r
private static RelDataType convertDecimalType(RelDataType relDataType, RexBuilder rexBuilder) {
Preconditions.checkArgument(relDataType.getSqlTypeName() == SqlTypeName.DECIMAL);
- return RelToPlanNodeConverter.convertToColumnDataType(relDataType).toType(rexBuilder.getTypeFactory());
+ return ColumnDataType.fromRelDataType(relDataType).toType(rexBuilder.getTypeFactory());
}
@Nullable
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRule.java
index 15b18317c3aa..b86250b7bac7 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRule.java
@@ -33,8 +33,8 @@
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.pinot.calcite.rel.logical.PinotLogicalSortExchange;
+import org.apache.pinot.common.calcite.type.TypeFactory;
import org.apache.pinot.query.planner.logical.RexExpressionUtils;
-import org.apache.pinot.query.type.TypeFactory;
public class PinotSortExchangeCopyRule extends RelRule