Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public enum BuiltinFunctionName {
ARRAY(FunctionName.of("array")),
ARRAY_LENGTH(FunctionName.of("array_length")),
MAP_CONCAT(FunctionName.of("map_concat"), true),
MAP_APPEND(FunctionName.of("map_append"), true),
MVAPPEND(FunctionName.of("mvappend")),
MVJOIN(FunctionName.of("mvjoin")),
FORALL(FunctionName.of("forall")),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function.CollectionUDF;

import java.util.ArrayList;
import java.util.List;

/** Core logic for `mvappend` command to collect elements from list of args */
public class MVAppendCore {

/**
* Collect non-null elements from `args`. If an item is a list, it will collect non-null elements
* of the list. See {@ref MVAppendFunctionImplTest} for detailed behavior.
*/
public static List<Object> collectElements(Object... args) {
List<Object> elements = new ArrayList<>();

for (Object arg : args) {
if (arg == null) {
continue;
} else if (arg instanceof List) {
addListElements((List<?>) arg, elements);
} else {
elements.add(arg);
}
}

return elements.isEmpty() ? null : elements;
}

private static void addListElements(List<?> list, List<Object> elements) {
for (Object item : list) {
if (item != null) {
elements.add(item);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import static org.apache.calcite.sql.type.SqlTypeUtil.createArrayType;

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
import org.apache.calcite.adapter.enumerable.NullPolicy;
Expand Down Expand Up @@ -98,33 +97,6 @@ public Expression implement(
}

public static Object mvappend(Object... args) {
List<Object> elements = collectElements(args);
return elements.isEmpty() ? null : elements;
}

private static List<Object> collectElements(Object... args) {
List<Object> elements = new ArrayList<>();

for (Object arg : args) {
if (arg == null) {
continue;
}

if (arg instanceof List) {
addListElements((List<?>) arg, elements);
} else {
elements.add(arg);
}
}

return elements;
}

private static void addListElements(List<?> list, List<Object> elements) {
for (Object item : list) {
if (item != null) {
elements.add(item);
}
}
return MVAppendCore.collectElements(args);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function.CollectionUDF;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
import org.apache.calcite.adapter.enumerable.NullPolicy;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.opensearch.sql.expression.function.ImplementorUDF;
import org.opensearch.sql.expression.function.UDFOperandMetadata;

/**
* MapAppend function that merges two maps. All the values will be converted to list for type
* consistency.
*/
public class MapAppendFunctionImpl extends ImplementorUDF {

public MapAppendFunctionImpl() {
super(new MapAppendImplementor(), NullPolicy.ALL);
}

@Override
public SqlReturnTypeInference getReturnTypeInference() {
return sqlOperatorBinding -> {
RelDataTypeFactory typeFactory = sqlOperatorBinding.getTypeFactory();
return typeFactory.createMapType(
typeFactory.createSqlType(SqlTypeName.VARCHAR),
typeFactory.createSqlType(SqlTypeName.ANY));
};
}

@Override
public UDFOperandMetadata getOperandMetadata() {
return null;
}

public static class MapAppendImplementor implements NotNullImplementor {
@Override
public Expression implement(
RexToLixTranslator translator, RexCall call, List<Expression> translatedOperands) {
if (translatedOperands.size() != 2) {
throw new IllegalArgumentException("MAP_APPEND function requires exactly 2 arguments");
}

return Expressions.call(
Types.lookupMethod(MapAppendFunctionImpl.class, "mapAppend", Object.class, Object.class),
translatedOperands.get(0),
translatedOperands.get(1));
}
}

public static Object mapAppend(Object map1, Object map2) {
if (map1 == null && map2 == null) {
return null;
}
if (map1 == null) {
return mapAppendImpl(verifyMap(map2));
}
if (map2 == null) {
return mapAppendImpl(verifyMap(map1));
}

return mapAppendImpl(verifyMap(map1), verifyMap(map2));
}

@SuppressWarnings("unchecked")
private static Map<String, Object> verifyMap(Object map) {
if (!(map instanceof Map)) {
throw new IllegalArgumentException(
"MAP_APPEND function requires both arguments to be MAP type");
}
return (Map<String, Object>) map;
}

static Map<String, Object> mapAppendImpl(Map<String, Object> map) {
Map<String, Object> result = new HashMap<>();
for (String key : map.keySet()) {
result.put(key, MVAppendCore.collectElements(map.get(key)));
}
return result;
}

static Map<String, Object> mapAppendImpl(
Map<String, Object> firstMap, Map<String, Object> secondMap) {
Map<String, Object> result = new HashMap<>();

for (String key : mergeKeys(firstMap, secondMap)) {
result.put(key, MVAppendCore.collectElements(firstMap.get(key), secondMap.get(key)));
}

return result;
}

private static Set<String> mergeKeys(
Map<String, Object> firstMap, Map<String, Object> secondMap) {
Set<String> keys = new HashSet<>();
keys.addAll(firstMap.keySet());
keys.addAll(secondMap.keySet());
return keys;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.sql.expression.function.CollectionUDF.FilterFunctionImpl;
import org.opensearch.sql.expression.function.CollectionUDF.ForallFunctionImpl;
import org.opensearch.sql.expression.function.CollectionUDF.MVAppendFunctionImpl;
import org.opensearch.sql.expression.function.CollectionUDF.MapAppendFunctionImpl;
import org.opensearch.sql.expression.function.CollectionUDF.ReduceFunctionImpl;
import org.opensearch.sql.expression.function.CollectionUDF.TransformFunctionImpl;
import org.opensearch.sql.expression.function.jsonUDF.JsonAppendFunctionImpl;
Expand Down Expand Up @@ -385,6 +386,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
public static final SqlOperator EXISTS = new ExistsFunctionImpl().toUDF("exists");
public static final SqlOperator ARRAY = new ArrayFunctionImpl().toUDF("array");
public static final SqlOperator MVAPPEND = new MVAppendFunctionImpl().toUDF("mvappend");
public static final SqlOperator MAP_APPEND = new MapAppendFunctionImpl().toUDF("map_append");
public static final SqlOperator FILTER = new FilterFunctionImpl().toUDF("filter");
public static final SqlOperator TRANSFORM = new TransformFunctionImpl().toUDF("transform");
public static final SqlOperator REDUCE = new ReduceFunctionImpl().toUDF("reduce");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.LTRIM;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAKEDATE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAKETIME;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAP_APPEND;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAP_CONCAT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MATCH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MATCH_BOOL_PREFIX;
Expand Down Expand Up @@ -860,6 +861,7 @@ void populate() {

registerOperator(ARRAY, PPLBuiltinOperators.ARRAY);
registerOperator(MVAPPEND, PPLBuiltinOperators.MVAPPEND);
registerOperator(MAP_APPEND, PPLBuiltinOperators.MAP_APPEND);
registerOperator(ARRAY_LENGTH, SqlLibraryOperators.ARRAY_LENGTH);
registerOperator(MAP_CONCAT, SqlLibraryOperators.MAP_CONCAT);
registerOperator(FORALL, PPLBuiltinOperators.FORALL);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function.CollectionUDF;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;

class MapAppendFunctionImplTest {
@Test
void testMapAppendWithNonOverlappingKeys() {
Map<String, Object> map1 = getMap1();
Map<String, Object> map2 = getMap2();

Map<String, Object> result = MapAppendFunctionImpl.mapAppendImpl(map1, map2);

assertEquals(4, result.size());
assertMapListValues(result, "a", "value1");
assertMapListValues(result, "b", "value2");
assertMapListValues(result, "c", "value3");
assertMapListValues(result, "d", "value4");
}

@Test
void testMapAppendWithOverlappingKeys() {
Map<String, Object> map1 = getMap1();
Map<String, Object> map2 = Map.of("b", "value3", "c", "value4");

Map<String, Object> result = MapAppendFunctionImpl.mapAppendImpl(map1, map2);

assertEquals(3, result.size());
assertMapListValues(result, "a", "value1");
assertMapListValues(result, "b", "value2", "value3");
assertMapListValues(result, "c", "value4");
}

@Test
void testMapAppendWithArrayValues() {
Map<String, Object> map1 = Map.of("a", List.of("item1", "item2"), "b", "single");
Map<String, Object> map2 = Map.of("a", "item3", "c", List.of("item4", "item5"));

Map<String, Object> result = MapAppendFunctionImpl.mapAppendImpl(map1, map2);

assertEquals(3, result.size());
assertMapListValues(result, "a", "item1", "item2", "item3");
assertMapListValues(result, "b", "single");
assertMapListValues(result, "c", "item4", "item5");
}

@Test
void testMapAppendWithNullValues() {
Map<String, Object> map1 = getMap1();
map1.put("b", null);
Map<String, Object> map2 = getMap2();
map2.put("b", "value2");
map2.put("a", null);

Map<String, Object> result = MapAppendFunctionImpl.mapAppendImpl(map1, map2);

assertEquals(4, result.size());
assertMapListValues(result, "a", "value1");
assertMapListValues(result, "b", "value2");
assertMapListValues(result, "c", "value3");
assertMapListValues(result, "d", "value4");
}

@Test
void testMapAppendWithSingleParam() {
Map<String, Object> map1 = getMap1();

Map<String, Object> result = MapAppendFunctionImpl.mapAppendImpl(map1);

assertEquals(2, result.size());
assertMapListValues(result, "a", "value1");
assertMapListValues(result, "b", "value2");
}

private Map<String, Object> getMap1() {
Map<String, Object> map1 = new HashMap<>();
map1.put("a", "value1");
map1.put("b", "value2");
return map1;
}

private Map<String, Object> getMap2() {
Map<String, Object> map2 = new HashMap<>();
map2.put("c", "value3");
map2.put("d", "value4");
return map2;
}

private void assertMapListValues(Map<String, Object> map, String key, Object... expectedValues) {
Object val = map.get(key);
assertTrue(val instanceof List);
List<Object> result = (List<Object>) val;
assertEquals(expectedValues.length, result.size());
for (int i = 0; i < expectedValues.length; i++) {
assertEquals(expectedValues[i], result.get(i));
}
}
}
Loading
Loading