Skip to content

Commit b520b7c

Browse files
committed
Add MAP type support to the ClickHouse connector
1 parent 42b7ca9 commit b520b7c

File tree

2 files changed

+143
-16
lines changed

2 files changed

+143
-16
lines changed

plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java

Lines changed: 124 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import com.clickhouse.data.ClickHouseColumn;
1717
import com.clickhouse.data.ClickHouseDataType;
1818
import com.clickhouse.data.ClickHouseVersion;
19+
import com.clickhouse.jdbc.JdbcTypeMapping;
1920
import com.google.common.base.Enums;
2021
import com.google.common.base.Splitter;
2122
import com.google.common.collect.ImmutableList;
@@ -44,6 +45,7 @@
4445
import io.trino.plugin.jdbc.JdbcTypeHandle;
4546
import io.trino.plugin.jdbc.LongReadFunction;
4647
import io.trino.plugin.jdbc.LongWriteFunction;
48+
import io.trino.plugin.jdbc.ObjectReadFunction;
4749
import io.trino.plugin.jdbc.ObjectWriteFunction;
4850
import io.trino.plugin.jdbc.QueryBuilder;
4951
import io.trino.plugin.jdbc.RemoteTableName;
@@ -61,7 +63,10 @@
6163
import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder;
6264
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
6365
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
66+
import io.trino.spi.StandardErrorCode;
6467
import io.trino.spi.TrinoException;
68+
import io.trino.spi.block.BlockBuilder;
69+
import io.trino.spi.block.SqlMap;
6570
import io.trino.spi.connector.AggregateFunction;
6671
import io.trino.spi.connector.ColumnHandle;
6772
import io.trino.spi.connector.ColumnMetadata;
@@ -74,9 +79,11 @@
7479
import io.trino.spi.type.DecimalType;
7580
import io.trino.spi.type.Decimals;
7681
import io.trino.spi.type.Int128;
82+
import io.trino.spi.type.MapType;
7783
import io.trino.spi.type.StandardTypes;
7884
import io.trino.spi.type.Type;
7985
import io.trino.spi.type.TypeManager;
86+
import io.trino.spi.type.TypeOperators;
8087
import io.trino.spi.type.TypeSignature;
8188
import io.trino.spi.type.VarbinaryType;
8289
import io.trino.spi.type.VarcharType;
@@ -99,6 +106,7 @@
99106
import java.time.LocalDateTime;
100107
import java.time.ZonedDateTime;
101108
import java.util.Collection;
109+
import java.util.HashMap;
102110
import java.util.List;
103111
import java.util.Map;
104112
import java.util.Map.Entry;
@@ -164,6 +172,7 @@
164172
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
165173
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
166174
import static io.trino.spi.StandardErrorCode.SCHEMA_NOT_EMPTY;
175+
import static io.trino.spi.block.MapValueBuilder.buildMapValue;
167176
import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE;
168177
import static io.trino.spi.type.BigintType.BIGINT;
169178
import static io.trino.spi.type.BooleanType.BOOLEAN;
@@ -212,6 +221,7 @@ public class ClickHouseClient
212221

213222
private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;
214223
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;
224+
private final TypeOperators typeOperators;
215225
private final Type uuidType;
216226
private final Type ipAddressType;
217227
private final AtomicReference<ClickHouseVersion> clickHouseVersion = new AtomicReference<>();
@@ -226,6 +236,7 @@ public ClickHouseClient(
226236
RemoteQueryModifier queryModifier)
227237
{
228238
super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, false);
239+
this.typeOperators = typeManager.getTypeOperators();
229240
this.uuidType = typeManager.getType(new TypeSignature(StandardTypes.UUID));
230241
this.ipAddressType = typeManager.getType(new TypeSignature(StandardTypes.IPADDRESS));
231242
JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
@@ -630,14 +641,27 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
630641
{
631642
String jdbcTypeName = typeHandle.jdbcTypeName()
632643
.orElseThrow(() -> new TrinoException(JDBC_ERROR, "Type name is missing: " + typeHandle));
633-
634644
Optional<ColumnMapping> mapping = getForcedMappingToVarchar(typeHandle);
635645
if (mapping.isPresent()) {
636646
return mapping;
637647
}
638648

649+
Optional<ColumnMapping> columnMapping = toColumnMapping(session, jdbcTypeName, typeHandle.jdbcType(), typeHandle.decimalDigits(), typeHandle.columnSize());
650+
if (columnMapping.isEmpty() && getUnsupportedTypeHandling(session) == CONVERT_TO_VARCHAR) {
651+
return mapToUnboundedVarchar(typeHandle);
652+
}
653+
return columnMapping;
654+
}
655+
656+
private Optional<ColumnMapping> toColumnMapping(
657+
ConnectorSession session,
658+
String typeName,
659+
int jdbcType,
660+
Optional<Integer> decimalDigits,
661+
Optional<Integer> columnSize)
662+
{
639663
ClickHouseVersion version = getClickHouseServerVersion(session);
640-
ClickHouseColumn column = ClickHouseColumn.of("", jdbcTypeName);
664+
ClickHouseColumn column = ClickHouseColumn.of("", typeName);
641665
ClickHouseDataType columnDataType = column.getDataType();
642666
switch (columnDataType) {
643667
case Bool:
@@ -677,11 +701,13 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
677701
return Optional.of(varbinaryColumnMapping());
678702
case UUID:
679703
return Optional.of(uuidColumnMapping());
704+
case Map:
705+
return mapColumnMapping(session, column.getKeyInfo(), column.getValueInfo());
680706
default:
681707
// no-op
682708
}
683709

684-
switch (typeHandle.jdbcType()) {
710+
switch (jdbcType) {
685711
case Types.TINYINT:
686712
return Optional.of(tinyintColumnMapping());
687713

@@ -706,16 +732,13 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
706732
return Optional.of(doubleColumnMapping());
707733

708734
case Types.DECIMAL:
709-
int decimalDigits = typeHandle.requiredDecimalDigits();
710-
int precision = typeHandle.requiredColumnSize();
711-
712735
ColumnMapping decimalColumnMapping;
713-
if (getDecimalRounding(session) == ALLOW_OVERFLOW && precision > Decimals.MAX_PRECISION) {
714-
int scale = Math.min(decimalDigits, getDecimalDefaultScale(session));
736+
if (getDecimalRounding(session) == ALLOW_OVERFLOW && columnSize.orElseThrow() > Decimals.MAX_PRECISION) {
737+
int scale = Math.min(decimalDigits.orElseThrow(), getDecimalDefaultScale(session));
715738
decimalColumnMapping = decimalColumnMapping(createDecimalType(Decimals.MAX_PRECISION, scale), getDecimalRoundingMode(session));
716739
}
717740
else {
718-
decimalColumnMapping = decimalColumnMapping(createDecimalType(precision, max(decimalDigits, 0)));
741+
decimalColumnMapping = decimalColumnMapping(createDecimalType(columnSize.orElseThrow(), max(decimalDigits.orElseThrow(), 0)));
719742
}
720743
return Optional.of(ColumnMapping.mapping(
721744
decimalColumnMapping.getType(),
@@ -730,7 +753,7 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
730753
case Types.TIMESTAMP:
731754
if (columnDataType == ClickHouseDataType.DateTime) {
732755
// ClickHouse DateTime does not have sub-second precision
733-
verify(typeHandle.requiredDecimalDigits() == 0, "Expected 0 as timestamp precision, but got %s", typeHandle.requiredDecimalDigits());
756+
verify(decimalDigits.orElseThrow() == 0, "Expected 0 as timestamp precision, but got %s", decimalDigits.orElseThrow());
734757
return Optional.of(ColumnMapping.longMapping(
735758
TIMESTAMP_SECONDS,
736759
timestampReadFunction(TIMESTAMP_SECONDS),
@@ -742,18 +765,14 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
742765
case Types.TIMESTAMP_WITH_TIMEZONE:
743766
if (columnDataType == ClickHouseDataType.DateTime) {
744767
// ClickHouse DateTime does not have sub-second precision
745-
verify(typeHandle.requiredDecimalDigits() == 0, "Expected 0 as timestamp with time zone precision, but got %s", typeHandle.requiredDecimalDigits());
768+
verify(decimalDigits.orElseThrow() == 0, "Expected 0 as timestamp with time zone precision, but got %s", decimalDigits.orElseThrow());
746769
return Optional.of(ColumnMapping.longMapping(
747770
TIMESTAMP_TZ_SECONDS,
748771
shortTimestampWithTimeZoneReadFunction(),
749772
shortTimestampWithTimeZoneWriteFunction(version, column.getTimeZone())));
750773
}
751774
}
752775

753-
if (getUnsupportedTypeHandling(session) == CONVERT_TO_VARCHAR) {
754-
return mapToUnboundedVarchar(typeHandle);
755-
}
756-
757776
return Optional.empty();
758777
}
759778

@@ -805,6 +824,12 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type)
805824
if (type.equals(uuidType)) {
806825
return WriteMapping.sliceMapping("UUID", uuidWriteFunction());
807826
}
827+
if (type instanceof MapType mapType) {
828+
WriteMapping keyMapping = toWriteMapping(session, mapType.getKeyType());
829+
WriteMapping valueMapping = toWriteMapping(session, mapType.getValueType());
830+
String dataType = "Map(%s, %s)".formatted(keyMapping.getDataType(), valueMapping.getDataType());
831+
return WriteMapping.objectMapping(dataType, mapWriteFunction());
832+
}
808833
throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type);
809834
}
810835

@@ -1003,6 +1028,90 @@ private static SliceWriteFunction uuidWriteFunction()
10031028
return (statement, index, value) -> statement.setObject(index, trinoUuidToJavaUuid(value), Types.OTHER);
10041029
}
10051030

1031+
private Optional<ColumnMapping> mapColumnMapping(ConnectorSession session, ClickHouseColumn keyColumn, ClickHouseColumn valueColumn)
1032+
{
1033+
JdbcTypeMapping typeMapping = JdbcTypeMapping.getDefaultMapping();
1034+
Optional<ColumnMapping> keyMapping = toColumnMapping(
1035+
session,
1036+
keyColumn.getDataType().name(),
1037+
typeMapping.toSqlType(keyColumn, Map.of()),
1038+
Optional.of(keyColumn.getPrecision()),
1039+
Optional.of(keyColumn.getScale()));
1040+
Optional<ColumnMapping> valueMapping = toColumnMapping(
1041+
session,
1042+
valueColumn.getDataType().name(),
1043+
typeMapping.toSqlType(valueColumn, Map.of()),
1044+
Optional.of(valueColumn.getPrecision()),
1045+
Optional.of(valueColumn.getScale()));
1046+
if (keyMapping.isEmpty() || valueMapping.isEmpty()) {
1047+
return Optional.empty();
1048+
}
1049+
1050+
MapType mapType = new MapType(
1051+
keyMapping.get().getType(),
1052+
valueMapping.get().getType(),
1053+
typeOperators);
1054+
return Optional.of(ColumnMapping.objectMapping(
1055+
mapType,
1056+
ObjectReadFunction.of(SqlMap.class, (resultSet, columnIndex) -> {
1057+
Object data = resultSet.getObject(columnIndex);
1058+
if (!(data instanceof Map<?, ?> mapData)) {
1059+
throw new TrinoException(StandardErrorCode.TYPE_MISMATCH, "Expected ClickHouse to return a Map");
1060+
}
1061+
1062+
return buildMapValue(
1063+
mapType,
1064+
mapData.size(),
1065+
(keyBuilder, valueBuilder) -> {
1066+
for (Object key : mapData.keySet()) {
1067+
writeValue(keyMapping.get().getType(), keyBuilder, key);
1068+
writeValue(valueMapping.get().getType(), valueBuilder, mapData.get(key));
1069+
}
1070+
});
1071+
}),
1072+
mapWriteFunction()));
1073+
}
1074+
1075+
private static ObjectWriteFunction mapWriteFunction()
1076+
{
1077+
return ObjectWriteFunction.of(SqlMap.class, (statement, index, value) -> {
1078+
MapType mapType = (MapType) value.getMapType();
1079+
Type keyType = mapType.getKeyType();
1080+
Type valueType = mapType.getValueType();
1081+
1082+
Map<Object, Object> mapValue = new HashMap<>();
1083+
for (int position = 0; position < value.getSize(); position++) {
1084+
Object keyEntry = keyType.getObjectValue(value.getRawKeyBlock(), position);
1085+
Object valueEntry = valueType.getObjectValue(value.getRawValueBlock(), position);
1086+
mapValue.put(keyEntry, valueEntry);
1087+
}
1088+
1089+
statement.setObject(index, mapValue);
1090+
});
1091+
}
1092+
1093+
private static void writeValue(Type type, BlockBuilder blockBuilder, Object value)
1094+
{
1095+
if (value == null) {
1096+
blockBuilder.appendNull();
1097+
}
1098+
else if (type.getJavaType() == long.class) {
1099+
type.writeLong(blockBuilder, ((Number) value).longValue());
1100+
}
1101+
else if (type.getJavaType() == double.class) {
1102+
type.writeDouble(blockBuilder, ((Number) value).doubleValue());
1103+
}
1104+
else if (type.getJavaType() == boolean.class) {
1105+
type.writeBoolean(blockBuilder, (boolean) value);
1106+
}
1107+
else if (type.getJavaType() == io.airlift.slice.Slice.class) {
1108+
type.writeSlice(blockBuilder, io.airlift.slice.Slices.utf8Slice(value.toString()));
1109+
}
1110+
else {
1111+
throw new UnsupportedOperationException("Unsupported type for map key or value: " + type);
1112+
}
1113+
}
1114+
10061115
public static boolean supportsPushdown(Variable variable, RewriteContext<ParameterizedExpression> context)
10071116
{
10081117
JdbcTypeHandle typeHandle = ((JdbcColumnHandle) context.getAssignment(variable.getName()))

plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,14 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
7272
SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN_WITH_LIKE,
7373
SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY,
7474
SUPPORTS_TOPN_PUSHDOWN,
75+
SUPPORTS_MAP_TYPE,
7576
SUPPORTS_TRUNCATE -> true;
7677
case SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION,
7778
SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV,
7879
SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE,
7980
SUPPORTS_ARRAY,
8081
SUPPORTS_DELETE,
8182
SUPPORTS_DROP_NOT_NULL_CONSTRAINT,
82-
SUPPORTS_MAP_TYPE,
8383
SUPPORTS_NEGATIVE_DATE,
8484
SUPPORTS_ROW_TYPE,
8585
SUPPORTS_SET_COLUMN_TYPE,
@@ -665,6 +665,24 @@ public void testInsertIntoNotNullColumn()
665665
}
666666
}
667667

668+
@Test
669+
@Override
670+
public void testInsertMap()
671+
{
672+
// TODO: Add more types here
673+
testMapRoundTrip("INTEGER", "2");
674+
testMapRoundTrip("VARCHAR", "CAST('foobar' AS VARCHAR)");
675+
}
676+
677+
private void testMapRoundTrip(String valueType, String value)
678+
{
679+
try (TestTable table = newTrinoTable("test_insert_map_", "(col map(INTEGER, %s) NOT NULL)".formatted(valueType))) {
680+
assertUpdate("INSERT INTO " + table.getName() + " VALUES map(ARRAY[1], ARRAY[%s])".formatted(value), 1);
681+
assertThat(query("SELECT col[1] FROM " + table.getName()))
682+
.matches("VALUES " + value);
683+
}
684+
}
685+
668686
@Override
669687
protected String errorMessageForCreateTableAsSelectNegativeDate(String date)
670688
{

0 commit comments

Comments
 (0)