Skip to content

Commit 090c150

Browse files
committed
feat: Register Spark array_min/max function with Orderable types
1 parent 4b1740d commit 090c150

File tree

5 files changed

+124
-6
lines changed

5 files changed

+124
-6
lines changed

velox/docs/functions/spark/array.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ Array Functions
9494
SELECT array_max(array(-1, -2, NULL)); -- -1
9595
SELECT array_max(array()); -- NULL
9696
SELECT array_max(array(-0.0001, -0.0002, -0.0003, float('nan'))); -- NaN
97+
SELECT array_max(array(array(1), array(NULL))) -- array(1)
9798

9899
.. spark:function:: array_min(array(E)) -> E
99100
@@ -105,8 +106,9 @@ Array Functions
105106
SELECT array_min(array(-1, -2, NULL)); -- -2
106107
SELECT array_min(array(NULL, NULL)); -- NULL
107108
SELECT array_min(array()); -- NULL
108-
SELECT array_min(array(4.0, float('nan')]); -- 4.0
109+
SELECT array_min(array(4.0, float('nan'))); -- 4.0
109110
SELECT array_min(array(NULL, float('nan'))); -- NaN
111+
SELECT array_min(array(array(1), array(NULL))) -- array(NULL)
110112

111113
.. spark:function:: array_position(x, element) -> bigint
112114

velox/functions/sparksql/ArrayMinMaxFunction.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,50 @@ struct ArrayMinMaxFunction {
106106
assign(out, currentValue);
107107
return true;
108108
}
109+
110+
bool compare(
111+
exec::GenericView currentValue,
112+
exec::GenericView candidateValue) {
113+
static constexpr CompareFlags kFlags = {
114+
.nullHandlingMode = CompareFlags::NullHandlingMode::kNullAsValue};
115+
116+
auto compareResult = candidateValue.compare(currentValue, kFlags).value();
117+
if constexpr (isMax) {
118+
return compareResult > 0;
119+
} else {
120+
return compareResult < 0;
121+
}
122+
}
123+
124+
bool call(
125+
out_type<Orderable<T1>>& out,
126+
const arg_type<Array<Orderable<T1>>>& array) {
127+
// Result is null if array is empty.
128+
if (array.size() == 0) {
129+
return false;
130+
}
131+
132+
int currentIndex = -1;
133+
for (auto i = 0; i < array.size(); i++) {
134+
if (array[i].has_value()) {
135+
if (currentIndex == -1) {
136+
currentIndex = i;
137+
} else {
138+
auto currentValue = array[currentIndex].value();
139+
auto candidateValue = array[i].value();
140+
if (compare(currentValue, candidateValue)) {
141+
currentIndex = i;
142+
}
143+
}
144+
}
145+
}
146+
if (currentIndex == -1) {
147+
// If array contains only NULL elements, return NULL.
148+
return false;
149+
}
150+
out.copy_from(array[currentIndex].value());
151+
return true;
152+
}
109153
};
110154

111155
template <typename TExecCtx>

velox/functions/sparksql/registration/RegisterArray.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,11 @@ inline void registerArrayMinMaxFunctions(const std::string& prefix) {
103103
registerArrayMinMaxFunctions<float>(prefix);
104104
registerArrayMinMaxFunctions<double>(prefix);
105105
registerArrayMinMaxFunctions<bool>(prefix);
106+
registerArrayMinMaxFunctions<Varbinary>(prefix);
106107
registerArrayMinMaxFunctions<Varchar>(prefix);
107108
registerArrayMinMaxFunctions<Timestamp>(prefix);
108109
registerArrayMinMaxFunctions<Date>(prefix);
110+
registerArrayMinMaxFunctions<Orderable<T1>>(prefix);
109111
}
110112

111113
template <typename T>

velox/functions/sparksql/tests/ArrayMaxTest.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ namespace {
2929
class ArrayMaxTest : public SparkFunctionBaseTest {
3030
protected:
3131
template <typename T>
32-
std::optional<T> arrayMax(const std::vector<std::optional<T>>& input) {
32+
std::optional<T> arrayMax(
33+
const std::vector<std::optional<T>>& input,
34+
const TypePtr& type = ARRAY(CppToType<T>::create())) {
3335
auto row = makeRowVector({makeNullableArrayVector(
34-
std::vector<std::vector<std::optional<T>>>{input})});
36+
std::vector<std::vector<std::optional<T>>>{input}, type)});
3537
return evaluateOnce<T>("array_max(C0)", row);
3638
}
3739
};
@@ -47,6 +49,17 @@ TEST_F(ArrayMaxTest, boolean) {
4749
EXPECT_EQ(arrayMax<bool>({true, true, true}), true);
4850
}
4951

52+
TEST_F(ArrayMaxTest, varbinary) {
53+
EXPECT_EQ(arrayMax<std::string>({"red", "blue"}, ARRAY(VARBINARY())), "red");
54+
EXPECT_EQ(
55+
arrayMax<std::string>(
56+
{std::nullopt, "blue", "yellow", "orange"}, ARRAY(VARBINARY())),
57+
"yellow");
58+
EXPECT_EQ(arrayMax<std::string>({}, ARRAY(VARBINARY())), std::nullopt);
59+
EXPECT_EQ(
60+
arrayMax<std::string>({std::nullopt}, ARRAY(VARBINARY())), std::nullopt);
61+
}
62+
5063
TEST_F(ArrayMaxTest, varchar) {
5164
EXPECT_EQ(arrayMax<std::string>({"red", "blue"}), "red");
5265
EXPECT_EQ(
@@ -87,6 +100,26 @@ TEST_F(ArrayMaxTest, timestamp) {
87100
EXPECT_EQ(arrayMax<Timestamp>({ts(0), std::nullopt}), ts(0));
88101
}
89102

103+
TEST_F(ArrayMaxTest, complexTypes) {
104+
auto testExpression = [&](const VectorPtr& input, const VectorPtr& expected) {
105+
auto result = evaluate("array_max(c0)", makeRowVector({input}));
106+
assertEqualVectors(expected, result);
107+
};
108+
testExpression(
109+
makeNestedArrayVectorFromJson<int64_t>(
110+
{"[[1, 1, 1], [1, 2, 2], [1, 3, 1]]"}),
111+
makeArrayVectorFromJson<int64_t>({"[1, 3, 1]"}));
112+
113+
testExpression(
114+
makeNestedArrayVectorFromJson<int64_t>(
115+
{"[[1, null], [null, 2], [null, null]]"}),
116+
makeArrayVectorFromJson<int64_t>({"[1, null]"}));
117+
118+
testExpression(
119+
makeNestedArrayVectorFromJson<int64_t>({"[null, null]"}),
120+
makeArrayVectorFromJson<int64_t>({"null"}));
121+
}
122+
90123
template <typename Type>
91124
class ArrayMaxIntegralTest : public ArrayMaxTest {
92125
public:

velox/functions/sparksql/tests/ArrayMinTest.cpp

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ namespace {
3030
class ArrayMinTest : public SparkFunctionBaseTest {
3131
protected:
3232
template <typename T>
33-
std::optional<T> arrayMin(const std::vector<std::optional<T>>& input) {
33+
std::optional<T> arrayMin(
34+
const std::vector<std::optional<T>>& input,
35+
const TypePtr& type = ARRAY(CppToType<T>::create())) {
3436
auto row = makeRowVector({makeNullableArrayVector(
35-
std::vector<std::vector<std::optional<T>>>{input})});
37+
std::vector<std::vector<std::optional<T>>>{input}, type)});
3638
return evaluateOnce<T>("array_min(C0)", row);
3739
}
3840
};
@@ -46,7 +48,18 @@ TEST_F(ArrayMinTest, boolean) {
4648
EXPECT_EQ(arrayMin<bool>({std::nullopt, true, false, true}), false);
4749
EXPECT_EQ(arrayMin<bool>({false, false, false}), false);
4850
EXPECT_EQ(arrayMin<bool>({true, true, true}), true);
49-
} // namespace
51+
}
52+
53+
TEST_F(ArrayMinTest, varbinary) {
54+
EXPECT_EQ(arrayMin<std::string>({"red", "blue"}, ARRAY(VARBINARY())), "blue");
55+
EXPECT_EQ(
56+
arrayMin<std::string>(
57+
{std::nullopt, "blue", "yellow", "orange"}, ARRAY(VARBINARY())),
58+
"blue");
59+
EXPECT_EQ(arrayMin<std::string>({}, ARRAY(VARBINARY())), std::nullopt);
60+
EXPECT_EQ(
61+
arrayMin<std::string>({std::nullopt}, ARRAY(VARBINARY())), std::nullopt);
62+
}
5063

5164
TEST_F(ArrayMinTest, varchar) {
5265
EXPECT_EQ(arrayMin<std::string>({"red", "blue"}), "blue");
@@ -88,6 +101,30 @@ TEST_F(ArrayMinTest, timestamp) {
88101
EXPECT_EQ(arrayMin<Timestamp>({ts(0), std::nullopt}), ts(0));
89102
}
90103

104+
TEST_F(ArrayMinTest, complexTypes) {
105+
auto testExpression = [&](const VectorPtr& input, const VectorPtr& expected) {
106+
auto result = evaluate("array_min(c0)", makeRowVector({input}));
107+
assertEqualVectors(expected, result);
108+
};
109+
testExpression(
110+
makeNestedArrayVectorFromJson<int64_t>(
111+
{"[[1, 1, 1], [1, 1, 2], [1, 3, 1]]"}),
112+
makeArrayVectorFromJson<int64_t>({"[1, 1, 1]"}));
113+
114+
testExpression(
115+
makeNestedArrayVectorFromJson<int64_t>(
116+
{"[[1, null], [null, 2], [null, null]]"}),
117+
makeArrayVectorFromJson<int64_t>({"[null, null]"}));
118+
119+
testExpression(
120+
makeNestedArrayVectorFromJson<int64_t>({"[[1], [null], []]"}),
121+
makeArrayVectorFromJson<int64_t>({"[]"}));
122+
123+
testExpression(
124+
makeNestedArrayVectorFromJson<int64_t>({"[null, null]"}),
125+
makeArrayVectorFromJson<int64_t>({"null"}));
126+
}
127+
91128
template <typename Type>
92129
class ArrayMinIntegralTest : public ArrayMinTest {
93130
public:

0 commit comments

Comments
 (0)