diff --git a/docs/src/main/sphinx/connector/clickhouse.md b/docs/src/main/sphinx/connector/clickhouse.md index 4af9cb75ddf8..f18960363d1c 100644 --- a/docs/src/main/sphinx/connector/clickhouse.md +++ b/docs/src/main/sphinx/connector/clickhouse.md @@ -248,6 +248,9 @@ to the following table: * - `UUID` - `UUID` - +* - `MAP(k, v)` + - `MAP(k, v)` + - ::: No other types are supported. @@ -307,6 +310,9 @@ to the following table: * - `UUID` - `UUID` - +* - `MAP(k, v)` + - `MAP(k, v)` + - ::: No other types are supported. diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java index 515ff2608117..1c302ca14285 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java @@ -16,6 +16,8 @@ import com.clickhouse.data.ClickHouseColumn; import com.clickhouse.data.ClickHouseDataType; import com.clickhouse.data.ClickHouseVersion; +import com.clickhouse.data.value.UnsignedLong; +import com.clickhouse.jdbc.JdbcTypeMapping; import com.google.common.base.Enums; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; @@ -44,7 +46,9 @@ import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.LongReadFunction; import io.trino.plugin.jdbc.LongWriteFunction; +import io.trino.plugin.jdbc.ObjectReadFunction; import io.trino.plugin.jdbc.ObjectWriteFunction; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.SliceWriteFunction; @@ -61,22 +65,30 @@ import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; +import io.trino.spi.StandardErrorCode; import io.trino.spi.TrinoException; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.MapHashTables; +import io.trino.spi.block.SqlMap; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ColumnPosition; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Variable; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; import io.trino.spi.type.Int128; +import io.trino.spi.type.MapType; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; +import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import io.trino.spi.type.VarbinaryType; import io.trino.spi.type.VarcharType; @@ -99,6 +111,7 @@ import java.time.LocalDateTime; import java.time.ZonedDateTime; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -116,6 +129,7 @@ import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.plugin.clickhouse.ClickHouseSessionProperties.isMapStringAsVarchar; import static io.trino.plugin.clickhouse.ClickHouseTableProperties.ENGINE_PROPERTY; @@ -164,6 +178,7 @@ import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.StandardErrorCode.SCHEMA_NOT_EMPTY; +import static io.trino.spi.block.MapValueBuilder.buildMapValue; import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -184,6 +199,7 @@ import static io.trino.spi.type.UuidType.javaUuidToTrinoUuid; import static io.trino.spi.type.UuidType.trinoUuidToJavaUuid; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static java.lang.Float.floatToIntBits; import static java.lang.Float.floatToRawIntBits; import static java.lang.Math.floorDiv; import static java.lang.Math.floorMod; @@ -212,6 +228,7 @@ public class ClickHouseClient private final ConnectorExpressionRewriter connectorExpressionRewriter; private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final TypeOperators typeOperators; private final Type uuidType; private final Type ipAddressType; private final AtomicReference clickHouseVersion = new AtomicReference<>(); @@ -226,6 +243,7 @@ public ClickHouseClient( RemoteQueryModifier queryModifier) { super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, false); + this.typeOperators = typeManager.getTypeOperators(); this.uuidType = typeManager.getType(new TypeSignature(StandardTypes.UUID)); this.ipAddressType = typeManager.getType(new TypeSignature(StandardTypes.IPADDRESS)); JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); @@ -237,6 +255,13 @@ public ClickHouseClient( .map("$not($is_null(value))").to("value IS NOT NULL") .map("$not(value: boolean)").to("NOT value") .map("$is_null(value)").to("value IS NULL") + .map("$equal(left, right)").to("left = right") + .map("$not_equal(left, right)").to("left <> right") + .map("$less_than(left, right)").to("left < right") + .map("$less_than_or_equal(left, right)").to("left <= right") + .map("$greater_than(left, right)").to("left > right") + .map("$greater_than_or_equal(left, right)").to("left >= right") + .map("$operator$subscript(map, value)").to("map[value]") .build(); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( this.connectorExpressionRewriter, @@ -264,9 +289,36 @@ public Optional implementAggregation(ConnectorSession session, A @Override public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) { + for (ColumnHandle columnHandle : assignments.values()) { + JdbcColumnHandle jdbcColumnHandle = (JdbcColumnHandle) columnHandle; + JdbcTypeHandle typeHandle = jdbcColumnHandle.getJdbcTypeHandle(); + String jdbcTypeName = typeHandle.jdbcTypeName() + .orElseThrow(() -> new TrinoException(JDBC_ERROR, "Type name is missing: " + typeHandle)); + if (jdbcColumnHandle.getColumnType() instanceof VarcharType && + getUnsupportedTypeHandling(session) == CONVERT_TO_VARCHAR && + toColumnMapping(session, jdbcTypeName, typeHandle.jdbcType(), typeHandle.decimalDigits(), typeHandle.columnSize()).isEmpty()) { + // Column is mapped to VARCHAR using unsupported type handling, predicate pushdown may not work properly + return Optional.empty(); + } + } + return connectorExpressionRewriter.rewrite(session, expression, assignments); } + @Override + public Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + Map leftProjections, + PreparedQuery rightSource, + Map rightProjections, + List joinConditions, + JoinStatistics statistics) + { + return Optional.empty(); + } + @Override public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder) { @@ -636,8 +688,22 @@ public Optional toColumnMapping(ConnectorSession session, Connect return mapping; } + Optional columnMapping = toColumnMapping(session, jdbcTypeName, typeHandle.jdbcType(), typeHandle.decimalDigits(), typeHandle.columnSize()); + if (columnMapping.isEmpty() && getUnsupportedTypeHandling(session) == CONVERT_TO_VARCHAR) { + return mapToUnboundedVarchar(typeHandle); + } + return columnMapping; + } + + private Optional toColumnMapping( + ConnectorSession session, + String typeName, + int jdbcType, + Optional decimalDigits, + Optional columnSize) + { ClickHouseVersion version = getClickHouseServerVersion(session); - ClickHouseColumn column = ClickHouseColumn.of("", jdbcTypeName); + ClickHouseColumn column = ClickHouseColumn.of("", typeName); ClickHouseDataType columnDataType = column.getDataType(); switch (columnDataType) { case Bool: @@ -677,11 +743,13 @@ public Optional toColumnMapping(ConnectorSession session, Connect return Optional.of(varbinaryColumnMapping()); case UUID: return Optional.of(uuidColumnMapping()); + case Map: + return mapColumnMapping(session, column.getKeyInfo(), column.getValueInfo()); default: // no-op } - switch (typeHandle.jdbcType()) { + switch (jdbcType) { case Types.TINYINT: return Optional.of(tinyintColumnMapping()); @@ -706,16 +774,13 @@ public Optional toColumnMapping(ConnectorSession session, Connect return Optional.of(doubleColumnMapping()); case Types.DECIMAL: - int decimalDigits = typeHandle.requiredDecimalDigits(); - int precision = typeHandle.requiredColumnSize(); - ColumnMapping decimalColumnMapping; - if (getDecimalRounding(session) == ALLOW_OVERFLOW && precision > Decimals.MAX_PRECISION) { - int scale = Math.min(decimalDigits, getDecimalDefaultScale(session)); + if (getDecimalRounding(session) == ALLOW_OVERFLOW && columnSize.orElseThrow() > Decimals.MAX_PRECISION) { + int scale = Math.min(decimalDigits.orElseThrow(), getDecimalDefaultScale(session)); decimalColumnMapping = decimalColumnMapping(createDecimalType(Decimals.MAX_PRECISION, scale), getDecimalRoundingMode(session)); } else { - decimalColumnMapping = decimalColumnMapping(createDecimalType(precision, max(decimalDigits, 0))); + decimalColumnMapping = decimalColumnMapping(createDecimalType(columnSize.orElseThrow(), max(decimalDigits.orElseThrow(), 0))); } return Optional.of(ColumnMapping.mapping( decimalColumnMapping.getType(), @@ -730,7 +795,7 @@ public Optional toColumnMapping(ConnectorSession session, Connect case Types.TIMESTAMP: if (columnDataType == ClickHouseDataType.DateTime) { // ClickHouse DateTime does not have sub-second precision - verify(typeHandle.requiredDecimalDigits() == 0, "Expected 0 as timestamp precision, but got %s", typeHandle.requiredDecimalDigits()); + verify(decimalDigits.orElseThrow() == 0, "Expected 0 as timestamp precision, but got %s", decimalDigits.orElseThrow()); return Optional.of(ColumnMapping.longMapping( TIMESTAMP_SECONDS, timestampReadFunction(TIMESTAMP_SECONDS), @@ -742,7 +807,7 @@ public Optional toColumnMapping(ConnectorSession session, Connect case Types.TIMESTAMP_WITH_TIMEZONE: if (columnDataType == ClickHouseDataType.DateTime) { // ClickHouse DateTime does not have sub-second precision - verify(typeHandle.requiredDecimalDigits() == 0, "Expected 0 as timestamp with time zone precision, but got %s", typeHandle.requiredDecimalDigits()); + verify(decimalDigits.orElseThrow() == 0, "Expected 0 as timestamp with time zone precision, but got %s", decimalDigits.orElseThrow()); return Optional.of(ColumnMapping.longMapping( TIMESTAMP_TZ_SECONDS, shortTimestampWithTimeZoneReadFunction(), @@ -750,10 +815,6 @@ public Optional toColumnMapping(ConnectorSession session, Connect } } - if (getUnsupportedTypeHandling(session) == CONVERT_TO_VARCHAR) { - return mapToUnboundedVarchar(typeHandle); - } - return Optional.empty(); } @@ -805,6 +866,12 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type) if (type.equals(uuidType)) { return WriteMapping.sliceMapping("UUID", uuidWriteFunction()); } + if (type instanceof MapType mapType) { + WriteMapping keyMapping = toWriteMapping(session, mapType.getKeyType()); + WriteMapping valueMapping = toWriteMapping(session, mapType.getValueType()); + String dataType = "Map(%s, %s)".formatted(keyMapping.getDataType(), valueMapping.getDataType()); + return WriteMapping.objectMapping(dataType, mapWriteFunction()); + } throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type); } @@ -946,26 +1013,30 @@ private ColumnMapping ipAddressColumnMapping(String clickhouseType) (resultSet, columnIndex) -> { // copied from IpAddressOperators.castFromVarcharToIpAddress byte[] address = InetAddresses.forString(resultSet.getString(columnIndex)).getAddress(); - - byte[] bytes; - if (address.length == 4) { - bytes = new byte[16]; - bytes[10] = (byte) 0xff; - bytes[11] = (byte) 0xff; - arraycopy(address, 0, bytes, 12, 4); - } - else if (address.length == 16) { - bytes = address; - } - else { - throw new TrinoException(GENERIC_INTERNAL_ERROR, "Invalid InetAddress length: " + address.length); - } - + byte[] bytes = parseIpAddressBytes(address); return wrappedBuffer(bytes); }, ipAddressWriteFunction(clickhouseType)); } + private static byte[] parseIpAddressBytes(byte[] address) + { + byte[] parsedBytes; + if (address.length == 4) { + parsedBytes = new byte[16]; + parsedBytes[10] = (byte) 0xff; + parsedBytes[11] = (byte) 0xff; + arraycopy(address, 0, parsedBytes, 12, 4); + } + else if (address.length == 16) { + parsedBytes = address; + } + else { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Invalid InetAddress length: " + address.length); + } + return parsedBytes; + } + private static SliceWriteFunction ipAddressWriteFunction(String clickhouseType) { return new SliceWriteFunction() @@ -1003,6 +1074,120 @@ private static SliceWriteFunction uuidWriteFunction() return (statement, index, value) -> statement.setObject(index, trinoUuidToJavaUuid(value), Types.OTHER); } + private Optional mapColumnMapping(ConnectorSession session, ClickHouseColumn keyColumn, ClickHouseColumn valueColumn) + { + JdbcTypeMapping typeMapping = JdbcTypeMapping.getDefaultMapping(); + Optional keyMapping = toColumnMapping( + session, + keyColumn.getOriginalTypeName(), + typeMapping.toSqlType(keyColumn, Map.of()), + Optional.of(keyColumn.getPrecision()), + Optional.of(keyColumn.getScale())); + Optional valueMapping = toColumnMapping( + session, + valueColumn.getOriginalTypeName(), + typeMapping.toSqlType(valueColumn, Map.of()), + Optional.of(valueColumn.getPrecision()), + Optional.of(valueColumn.getScale())); + if (keyMapping.isEmpty() || valueMapping.isEmpty()) { + return Optional.empty(); + } + + MapType mapType = new MapType( + keyMapping.get().getType(), + valueMapping.get().getType(), + typeOperators); + return Optional.of(ColumnMapping.objectMapping( + mapType, + ObjectReadFunction.of(SqlMap.class, (resultSet, columnIndex) -> { + Object data = resultSet.getObject(columnIndex); + if (!(data instanceof Map mapData)) { + throw new TrinoException(StandardErrorCode.TYPE_MISMATCH, "Expected ClickHouse to return a Map"); + } + + return buildMapValue( + mapType, + mapData.size(), + (keyBuilder, valueBuilder) -> { + for (Object key : mapData.keySet()) { + writeValue(keyMapping.get().getType(), keyBuilder, key); + writeValue(valueMapping.get().getType(), valueBuilder, mapData.get(key)); + } + }); + }), + mapWriteFunction(), + DISABLE_PUSHDOWN)); + } + + private static ObjectWriteFunction mapWriteFunction() + { + return ObjectWriteFunction.of(SqlMap.class, (statement, index, value) -> { + MapType mapType = (MapType) value.getMapType(); + Type keyType = mapType.getKeyType(); + Type valueType = mapType.getValueType(); + + Map mapValue = new HashMap<>(); + for (int position = 0; position < value.getSize(); position++) { + Object keyEntry = keyType.getObjectValue(value.getRawKeyBlock(), position); + Object valueEntry = valueType.getObjectValue(value.getRawValueBlock(), position); + mapValue.put(keyEntry, valueEntry); + } + + statement.setObject(index, mapValue); + }); + } + + private static void writeValue(Type type, BlockBuilder blockBuilder, Object value) + { + if (value == null) { + blockBuilder.appendNull(); + } + else if (type.getJavaType() == long.class) { + switch (value) { + case Float floatValue -> type.writeLong(blockBuilder, floatToIntBits(floatValue)); + case Double doubleValue -> type.writeLong(blockBuilder, Double.doubleToLongBits(doubleValue)); + case Number numberValue -> type.writeLong(blockBuilder, numberValue.longValue()); + case LocalDate dateValue -> type.writeLong(blockBuilder, dateValue.toEpochDay()); + default -> throw new UnsupportedOperationException("Unsupported type for map key or value: " + type); + } + } + else if (type.getJavaType() == double.class) { + type.writeDouble(blockBuilder, ((Number) value).doubleValue()); + } + else if (type.getJavaType() == boolean.class) { + type.writeBoolean(blockBuilder, (boolean) value); + } + else if (type.getJavaType() == io.airlift.slice.Slice.class) { + if (value instanceof InetAddress ipAddressValue) { + byte[] address = parseIpAddressBytes(ipAddressValue.getAddress()); + type.writeSlice(blockBuilder, wrappedBuffer(address)); + } + else if (value instanceof UUID uuidValue) { + type.writeSlice(blockBuilder, javaUuidToTrinoUuid(uuidValue)); + } + else { + type.writeSlice(blockBuilder, utf8Slice(value.toString())); + } + } + else if (type.getJavaType() == Int128.class) { + type.writeObject(blockBuilder, Int128.valueOf(((UnsignedLong) value).bigIntegerValue())); + } + else if (type.getJavaType() == SqlMap.class) { + Map mapValue = (Map) value; + MapType mapType = (MapType) type; + BlockBuilder keyBuilder = mapType.getKeyType().createBlockBuilder(null, 0); + BlockBuilder valueBuilder = mapType.getValueType().createBlockBuilder(null, 0); + for (Map.Entry mapEntry : mapValue.entrySet()) { + writeValue(mapType.getKeyType(), keyBuilder, mapEntry.getKey()); + writeValue(mapType.getValueType(), valueBuilder, mapEntry.getValue()); + } + type.writeObject(blockBuilder, new SqlMap(mapType, MapHashTables.HashBuildMode.DUPLICATE_NOT_CHECKED, keyBuilder.build(), valueBuilder.build())); + } + else { + throw new UnsupportedOperationException("Unsupported type for map key or value: " + type); + } + } + public static boolean supportsPushdown(Variable variable, RewriteContext context) { JdbcTypeHandle typeHandle = ((JdbcColumnHandle) context.getAssignment(variable.getName())) diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java index d42a6d9024b4..f742e93d793e 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java @@ -15,8 +15,11 @@ import com.google.common.collect.ImmutableList; import io.trino.Session; +import io.trino.spi.type.DecimalType; +import io.trino.spi.type.MapType; import io.trino.spi.type.TimeZoneKey; -import io.trino.spi.type.UuidType; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeOperators; import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.TestingSession; import io.trino.testing.datatype.CreateAndInsertDataSetup; @@ -53,6 +56,7 @@ import static io.trino.spi.type.TimestampType.createTimestampType; import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_SECONDS; import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.UuidType.UUID; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; @@ -1128,8 +1132,8 @@ public void testEnum() public void testUuid() { SqlDataTypeTest.create() - .addRoundTrip("Nullable(UUID)", "NULL", UuidType.UUID, "CAST(NULL AS UUID)") - .addRoundTrip("Nullable(UUID)", "'114514ea-0601-1981-1142-e9b55b0abd6d'", UuidType.UUID, "CAST('114514ea-0601-1981-1142-e9b55b0abd6d' AS UUID)") + .addRoundTrip("Nullable(UUID)", "NULL", UUID, "CAST(NULL AS UUID)") + .addRoundTrip("Nullable(UUID)", "'114514ea-0601-1981-1142-e9b55b0abd6d'", UUID, "CAST('114514ea-0601-1981-1142-e9b55b0abd6d' AS UUID)") .execute(getQueryRunner(), clickhouseCreateAndInsert("default.ck_test_uuid")); SqlDataTypeTest.create() @@ -1165,6 +1169,67 @@ public void testIp() .execute(getQueryRunner(), clickhouseCreateAndTrinoInsert("tpch.test_ip")); } + @Test + public void testMap() + { + addClickhouseCreateAndInsertDataTypeTests(SqlDataTypeTest.create()) + .execute(getQueryRunner(), mapStringAsVarcharSession(), clickhouseCreateAndInsert("tpch.test_map")); + + SqlDataTypeTest.create() + .addRoundTrip("Map(Int32, Bool)", "MAP(ARRAY[42], ARRAY[true])", mapWithValueType(BOOLEAN)) + .addRoundTrip("Map(Int32, Int8)", "MAP(ARRAY[42], ARRAY[TINYINT '0'])", mapWithValueType(TINYINT)) + .addRoundTrip("Map(Int32, Int32)", "MAP(ARRAY[42], ARRAY[1])", mapWithValueType(INTEGER)) + .addRoundTrip("Map(Int32, Int64)", "MAP(ARRAY[42], ARRAY[BIGINT '2'])", mapWithValueType(BIGINT)) + .addRoundTrip("Map(Int32, UInt8)", "MAP(ARRAY[42], ARRAY[SMALLINT '3'])", mapWithValueType(SMALLINT)) + .addRoundTrip("Map(Int32, UInt16)", "MAP(ARRAY[42], ARRAY[4])", mapWithValueType(INTEGER)) + .addRoundTrip("Map(Int32, UInt32)", "MAP(ARRAY[42], ARRAY[BIGINT '5'])", mapWithValueType(BIGINT)) + .addRoundTrip("Map(Int32, UInt64)", "MAP(ARRAY[42], ARRAY[CAST(6 AS DECIMAL(20, 0))])", mapWithValueType(DecimalType.createDecimalType(20, 0))) + .addRoundTrip("Map(Int32, Float32)", "MAP(ARRAY[42], ARRAY[REAL '7'])", mapWithValueType(REAL)) + .addRoundTrip("Map(Int32, Float64)", "MAP(ARRAY[42], ARRAY[DOUBLE '8'])", mapWithValueType(DOUBLE)) + .addRoundTrip("Map(Int32, FixedString(4))", "MAP(ARRAY[42], ARRAY['nine'])", mapWithValueType(VARCHAR), "MAP(ARRAY[42], ARRAY[CAST('nine' AS VARCHAR)])") + .addRoundTrip("Map(Int32, String)", "MAP(ARRAY[42], ARRAY['ten'])", mapWithValueType(VARCHAR), "MAP(ARRAY[42], ARRAY[CAST('ten' AS VARCHAR)])") + .addRoundTrip("Map(Int32, Date)", "MAP(ARRAY[42], ARRAY[DATE '2011-11-11'])", mapWithValueType(DATE)) + .addRoundTrip("Map(Int32, IPv4)", "MAP(ARRAY[42], ARRAY[IPADDRESS '213.213.213.213'])", mapWithValueType(IPADDRESS)) + .addRoundTrip("Map(Int32, IPv6)", "MAP(ARRAY[42], ARRAY[IPADDRESS '2001:44c8:129:2632:33:0:252:14'])", mapWithValueType(IPADDRESS)) + .addRoundTrip("Map(Int32, Enum('hello' = 1, 'world' = 2))", "MAP(ARRAY[42], ARRAY[CAST('world' AS VARCHAR)])", mapWithValueType(VARCHAR)) + .addRoundTrip("Map(Int32, UUID)", "MAP(ARRAY[42], ARRAY[UUID '92d3f742-b13c-4d8e-9d7a-1130d2d31980'])", mapWithValueType(UUID)) + .addRoundTrip("Map(Int32, Map(Int32, Int32))", "MAP(ARRAY[42], ARRAY[MAP(ARRAY[1], ARRAY[2])])", mapWithValueType(mapWithValueType(INTEGER))) + .execute(getQueryRunner(), mapStringAsVarcharSession(), clickhouseCreateAndTrinoInsert(mapStringAsVarcharSession(), "tpch.test_map")); + } + + protected SqlDataTypeTest addClickhouseCreateAndInsertDataTypeTests(SqlDataTypeTest dataTypeTest) + { + dataTypeTest.addRoundTrip("Map(Int32, Bool)", "map(42, true)", mapWithValueType(BOOLEAN), "MAP(ARRAY[42], ARRAY[true])") + .addRoundTrip("Map(Int32, Int8)", "map(42, 0)", mapWithValueType(TINYINT), "MAP(ARRAY[42], ARRAY[TINYINT '0'])") + .addRoundTrip("Map(Int32, Int32)", "map(42, 1)", mapWithValueType(INTEGER), "MAP(ARRAY[42], ARRAY[1])") + .addRoundTrip("Map(Int32, Int64)", "map(42, 2)", mapWithValueType(BIGINT), "MAP(ARRAY[42], ARRAY[BIGINT '2'])") + .addRoundTrip("Map(Int32, UInt8)", "map(42, 3)", mapWithValueType(SMALLINT), "MAP(ARRAY[42], ARRAY[SMALLINT '3'])") + .addRoundTrip("Map(Int32, UInt16)", "map(42, 4)", mapWithValueType(INTEGER), "MAP(ARRAY[42], ARRAY[4])") + .addRoundTrip("Map(Int32, UInt32)", "map(42, 5)", mapWithValueType(BIGINT), "MAP(ARRAY[42], ARRAY[BIGINT '5'])") + .addRoundTrip("Map(Int32, UInt64)", "map(42, 6)", mapWithValueType(DecimalType.createDecimalType(20, 0)), "MAP(ARRAY[42], ARRAY[CAST(6 AS DECIMAL(20, 0))])") + .addRoundTrip("Map(Int32, Float32)", "map(42, 7)", mapWithValueType(REAL), "MAP(ARRAY[42], ARRAY[REAL '7'])") + .addRoundTrip("Map(Int32, Float64)", "map(42, 8)", mapWithValueType(DOUBLE), "MAP(ARRAY[42], ARRAY[DOUBLE '8'])") + .addRoundTrip("Map(Int32, FixedString(4))", "map(42, 'nine')", mapWithValueType(VARCHAR), "MAP(ARRAY[42], ARRAY[CAST('nine' AS VARCHAR)])") + .addRoundTrip("Map(Int32, String)", "map(42, 'ten')", mapWithValueType(VARCHAR), "MAP(ARRAY[42], ARRAY[CAST('ten' AS VARCHAR)])") + .addRoundTrip("Map(Int32, Date)", "map(42, '2011-11-11')", mapWithValueType(DATE), "MAP(ARRAY[42], ARRAY[DATE '2011-11-11'])") + .addRoundTrip("Map(Int32, IPv4)", "map(42, '213.213.213.213')", mapWithValueType(IPADDRESS), "MAP(ARRAY[42], ARRAY[IPADDRESS '213.213.213.213'])") + .addRoundTrip("Map(Int32, IPv6)", "map(42, '2001:44c8:129:2632:33:0:252:14')", mapWithValueType(IPADDRESS), "MAP(ARRAY[42], ARRAY[IPADDRESS '2001:44c8:129:2632:33:0:252:14'])") + .addRoundTrip("Map(Int32, Enum('hello' = 1, 'world' = 2))", "map(42, 'world')", mapWithValueType(VARCHAR), "MAP(ARRAY[42], ARRAY[CAST('world' AS VARCHAR)])") + .addRoundTrip("Map(Int32, UUID)", "map(42, '92d3f742-b13c-4d8e-9d7a-1130d2d31980')", mapWithValueType(UUID), "MAP(ARRAY[42], ARRAY[UUID '92d3f742-b13c-4d8e-9d7a-1130d2d31980'])"); + // TODO: This fails because the timestamp precision is set to 29 +// .addRoundTrip( +// "Map(Int32, DateTime('Asia/Kathmandu'))", +// "map(42, '2012-12-01 12:00:00')", +// mapWithValueType(TIMESTAMP_TZ_SECONDS), +// "MAP(ARRAY[42], ARRAY[TIMESTAMP '2012-12-01 12:00:00 +05:45'])") + return dataTypeTest; + } + + protected static Type mapWithValueType(Type valueType) + { + return new MapType(INTEGER, valueType, new TypeOperators()); + } + @Test public void testUnsupportedPoint() { @@ -1215,6 +1280,11 @@ protected DataSetup clickhouseCreateAndTrinoInsert(String tableNamePrefix) return new CreateAndTrinoInsertDataSetup(new ClickHouseSqlExecutor(onRemoteDatabase()), new TrinoSqlExecutor(getQueryRunner()), tableNamePrefix); } + protected DataSetup clickhouseCreateAndTrinoInsert(Session session, String tableNamePrefix) + { + return new CreateAndTrinoInsertDataSetup(new ClickHouseSqlExecutor(onRemoteDatabase()), new TrinoSqlExecutor(getQueryRunner(), session), tableNamePrefix); + } + protected SqlExecutor onRemoteDatabase() { return clickhouseServer::execute; diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestAltinityClickHouseTypeMapping.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestAltinityClickHouseTypeMapping.java index a9ea166cdd19..adc430687207 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestAltinityClickHouseTypeMapping.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestAltinityClickHouseTypeMapping.java @@ -13,9 +13,21 @@ */ package io.trino.plugin.clickhouse; +import io.trino.spi.type.DecimalType; import io.trino.testing.QueryRunner; +import io.trino.testing.datatype.SqlDataTypeTest; import static io.trino.plugin.clickhouse.TestingClickHouseServer.ALTINITY_DEFAULT_IMAGE; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.UuidType.UUID; +import static io.trino.spi.type.VarcharType.VARCHAR; final class TestAltinityClickHouseTypeMapping extends BaseClickHouseTypeMapping @@ -27,4 +39,25 @@ protected QueryRunner createQueryRunner() clickhouseServer = closeAfterClass(new TestingClickHouseServer(ALTINITY_DEFAULT_IMAGE)); return ClickHouseQueryRunner.builder(clickhouseServer).build(); } -} + + @Override + protected SqlDataTypeTest addClickhouseCreateAndInsertDataTypeTests(SqlDataTypeTest dataTypeTest) + { + // Older versions of Altinity are missing a String -> IPADDRESS cast which results in the test setup for those tests failing + dataTypeTest.addRoundTrip("Map(Int32, Bool)", "map(42, true)", mapWithValueType(BOOLEAN), "MAP(ARRAY[42], ARRAY[true])") + .addRoundTrip("Map(Int32, Int8)", "map(42, 0)", mapWithValueType(TINYINT), "MAP(ARRAY[42], ARRAY[TINYINT '0'])") + .addRoundTrip("Map(Int32, Int32)", "map(42, 1)", mapWithValueType(INTEGER), "MAP(ARRAY[42], ARRAY[1])") + .addRoundTrip("Map(Int32, Int64)", "map(42, 2)", mapWithValueType(BIGINT), "MAP(ARRAY[42], ARRAY[BIGINT '2'])") + .addRoundTrip("Map(Int32, UInt8)", "map(42, 3)", mapWithValueType(SMALLINT), "MAP(ARRAY[42], ARRAY[SMALLINT '3'])") + .addRoundTrip("Map(Int32, UInt16)", "map(42, 4)", mapWithValueType(INTEGER), "MAP(ARRAY[42], ARRAY[4])") + .addRoundTrip("Map(Int32, UInt32)", "map(42, 5)", mapWithValueType(BIGINT), "MAP(ARRAY[42], ARRAY[BIGINT '5'])") + .addRoundTrip("Map(Int32, UInt64)", "map(42, 6)", mapWithValueType(DecimalType.createDecimalType(20, 0)), "MAP(ARRAY[42], ARRAY[CAST(6 AS DECIMAL(20, 0))])") + .addRoundTrip("Map(Int32, Float32)", "map(42, 7)", mapWithValueType(REAL), "MAP(ARRAY[42], ARRAY[REAL '7'])") + .addRoundTrip("Map(Int32, Float64)", "map(42, 8)", mapWithValueType(DOUBLE), "MAP(ARRAY[42], ARRAY[DOUBLE '8'])") + .addRoundTrip("Map(Int32, FixedString(4))", "map(42, 'nine')", mapWithValueType(VARCHAR), "MAP(ARRAY[42], ARRAY[CAST('nine' AS VARCHAR)])") + .addRoundTrip("Map(Int32, String)", "map(42, 'ten')", mapWithValueType(VARCHAR), "MAP(ARRAY[42], ARRAY[CAST('ten' AS VARCHAR)])") + .addRoundTrip("Map(Int32, Date)", "map(42, '2011-11-11')", mapWithValueType(DATE), "MAP(ARRAY[42], ARRAY[DATE '2011-11-11'])") + .addRoundTrip("Map(Int32, Enum('hello' = 1, 'world' = 2))", "map(42, 'world')", mapWithValueType(VARCHAR), "MAP(ARRAY[42], ARRAY[CAST('world' AS VARCHAR)])") + .addRoundTrip("Map(Int32, UUID)", "map(42, '92d3f742-b13c-4d8e-9d7a-1130d2d31980')", mapWithValueType(UUID), "MAP(ARRAY[42], ARRAY[UUID '92d3f742-b13c-4d8e-9d7a-1130d2d31980'])"); + return dataTypeTest; + }} diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java index 99aa2ff7e1e3..35e5810fdffa 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java @@ -72,6 +72,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN_WITH_LIKE, SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY, SUPPORTS_TOPN_PUSHDOWN, + SUPPORTS_MAP_TYPE, SUPPORTS_TRUNCATE -> true; case SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION, SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV, @@ -79,7 +80,6 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_ARRAY, SUPPORTS_DELETE, SUPPORTS_DROP_NOT_NULL_CONSTRAINT, - SUPPORTS_MAP_TYPE, SUPPORTS_NEGATIVE_DATE, SUPPORTS_ROW_TYPE, SUPPORTS_SET_COLUMN_TYPE, @@ -665,6 +665,28 @@ public void testInsertIntoNotNullColumn() } } + @Test + @Override + public void testInsertMap() + { + try (TestTable table = newTrinoTable("test_insert_map_", "(col map(INTEGER, BIGINT) NOT NULL)")) { + assertUpdate("INSERT INTO " + table.getName() + " VALUES map(ARRAY[1], ARRAY[BIGINT '123456789'])", 1); + assertThat(query("SELECT col[1] FROM " + table.getName())) + .matches("VALUES BIGINT '123456789'"); + } + } + + @Test + public void testMapPredicatePushdown() + { + try (TestTable table = newTrinoTable("test_map_predicate_pushdown", "(id INT, map_t map(INTEGER, BIGINT) NOT NULL)")) { + assertUpdate("INSERT INTO " + table.getName() + " VALUES (1, map(ARRAY[1], ARRAY[BIGINT '123456789']))", 1); + assertThat(query("SELECT id FROM " + table.getName() + " WHERE map_t[1] = BIGINT '123456789'")) + .matches("VALUES 1") + .isFullyPushedDown(); + } + } + @Override protected String errorMessageForCreateTableAsSelectNegativeDate(String date) { @@ -951,19 +973,19 @@ a_enum_2 Enum('hello', 'world', 'a', 'b', 'c', '%', '_')) assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string = a_string_alias")).isFullyPushedDown(); assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string = a_string_alias" + withConnectorExpression)).isFullyPushedDown(); - assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string = a_enum_1")).isNotFullyPushedDown(FilterNode.class); - assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string = a_enum_1" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string = a_enum_1")).isFullyPushedDown(); + assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string = a_enum_1" + withConnectorExpression)).isFullyPushedDown(); assertThat(query(convertToVarchar, "SELECT some_column FROM " + table.getName() + " WHERE a_string = unsupported_1")).isNotFullyPushedDown(FilterNode.class); assertThat(query(convertToVarchar, "SELECT some_column FROM " + table.getName() + " WHERE a_string = unsupported_1" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class); assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 = 'hello'")).isNotFullyPushedDown(FilterNode.class); - assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 = 'hello'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 = 'hello'" + withConnectorExpression)).isFullyPushedDown(); assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 = 'not_a_value'")).isNotFullyPushedDown(FilterNode.class); - assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 = 'not_a_value'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 = 'not_a_value'" + withConnectorExpression)).isFullyPushedDown(); // pushdown of a condition, both sides of the same native type, which is mapped to varchar, // not allowed because some operations (e.g. inequalities) may not be allowed in the native system on an unknown native types - assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 = a_enum_2")).isNotFullyPushedDown(FilterNode.class); - assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 = a_enum_2" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 = a_enum_2")).isFullyPushedDown(); + assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 = a_enum_2" + withConnectorExpression)).isFullyPushedDown(); // pushdown of a condition, both sides of the same native type, which is mapped to varchar, // not allowed because some operations (e.g. inequalities) may not be allowed in the native system on an unknown native types