Skip to content

Commit

Permalink
feat: Register Spark array_min/max function with Orderable types
Browse files Browse the repository at this point in the history
  • Loading branch information
zhli1142015 committed Mar 11, 2025
1 parent 4b1740d commit 090c150
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 6 deletions.
4 changes: 3 additions & 1 deletion velox/docs/functions/spark/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ Array Functions
SELECT array_max(array(-1, -2, NULL)); -- -1
SELECT array_max(array()); -- NULL
SELECT array_max(array(-0.0001, -0.0002, -0.0003, float('nan'))); -- NaN
SELECT array_max(array(array(1), array(NULL))); -- array(1)

.. spark:function:: array_min(array(E)) -> E
Expand All @@ -105,8 +106,9 @@ Array Functions
SELECT array_min(array(-1, -2, NULL)); -- -2
SELECT array_min(array(NULL, NULL)); -- NULL
SELECT array_min(array()); -- NULL
SELECT array_min(array(4.0, float('nan')]); -- 4.0
SELECT array_min(array(4.0, float('nan'))); -- 4.0
SELECT array_min(array(NULL, float('nan'))); -- NaN
SELECT array_min(array(array(1), array(NULL))); -- array(NULL)

.. spark:function:: array_position(x, element) -> bigint
Expand Down
44 changes: 44 additions & 0 deletions velox/functions/sparksql/ArrayMinMaxFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,50 @@ struct ArrayMinMaxFunction {
assign(out, currentValue);
return true;
}

// Tells whether 'candidateValue' should replace 'currentValue' as the running
// result: strictly greater when isMax is set, strictly smaller otherwise.
// Equal values yield false.
//
// The comparison runs with NullHandlingMode::kNullAsValue, i.e. nulls found
// inside the compared values are treated as ordinary comparable values.
bool compare(
    exec::GenericView currentValue,
    exec::GenericView candidateValue) {
  static constexpr CompareFlags kFlags = {
      .nullHandlingMode = CompareFlags::NullHandlingMode::kNullAsValue};

  // NOTE(review): .value() presumes the comparison is always determinate in
  // kNullAsValue mode - confirm against GenericView::compare.
  const auto ordering = candidateValue.compare(currentValue, kFlags).value();
  return isMax ? ordering > 0 : ordering < 0;
}

// Picks the smallest (or largest, when isMax is set) non-null element of
// 'array' and copies it into 'out'.
//
// Returns true when a result was written. Returns false, producing a NULL
// result, when the array is empty or contains only NULL elements.
bool call(
    out_type<Orderable<T1>>& out,
    const arg_type<Array<Orderable<T1>>>& array) {
  // Position of the best non-null element seen so far; -1 until one is found.
  int bestIndex = -1;
  for (auto idx = 0; idx < array.size(); ++idx) {
    if (!array[idx].has_value()) {
      // NULL elements never take part in the comparison.
      continue;
    }
    // The first non-null element seeds the result; later ones replace it only
    // when compare() reports a strictly better candidate.
    if (bestIndex == -1 ||
        compare(array[bestIndex].value(), array[idx].value())) {
      bestIndex = idx;
    }
  }

  // No non-null element found: the array was empty or all NULL.
  if (bestIndex == -1) {
    return false;
  }

  out.copy_from(array[bestIndex].value());
  return true;
}
};

template <typename TExecCtx>
Expand Down
2 changes: 2 additions & 0 deletions velox/functions/sparksql/registration/RegisterArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,11 @@ inline void registerArrayMinMaxFunctions(const std::string& prefix) {
registerArrayMinMaxFunctions<float>(prefix);
registerArrayMinMaxFunctions<double>(prefix);
registerArrayMinMaxFunctions<bool>(prefix);
registerArrayMinMaxFunctions<Varbinary>(prefix);
registerArrayMinMaxFunctions<Varchar>(prefix);
registerArrayMinMaxFunctions<Timestamp>(prefix);
registerArrayMinMaxFunctions<Date>(prefix);
registerArrayMinMaxFunctions<Orderable<T1>>(prefix);
}

template <typename T>
Expand Down
37 changes: 35 additions & 2 deletions velox/functions/sparksql/tests/ArrayMaxTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ namespace {
// Fixture for array_max tests. Provides a helper that evaluates the function
// on a single array and returns the scalar result.
class ArrayMaxTest : public SparkFunctionBaseTest {
protected:
// Wraps 'input' into a row vector holding one nullable array column and
// returns the result of evaluating "array_max(C0)" on it via evaluateOnce.
// 'type' is the type passed to makeNullableArrayVector; it defaults to
// ARRAY(CppToType<T>::create()).
template <typename T>
std::optional<T> arrayMax(const std::vector<std::optional<T>>& input) {
std::optional<T> arrayMax(
const std::vector<std::optional<T>>& input,
const TypePtr& type = ARRAY(CppToType<T>::create())) {
auto row = makeRowVector({makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{input})});
std::vector<std::vector<std::optional<T>>>{input}, type)});
return evaluateOnce<T>("array_max(C0)", row);
}
};
Expand All @@ -47,6 +49,17 @@ TEST_F(ArrayMaxTest, boolean) {
EXPECT_EQ(arrayMax<bool>({true, true, true}), true);
}

// array_max over ARRAY(VARBINARY) inputs: plain values, a leading NULL, an
// empty array and an all-NULL array.
TEST_F(ArrayMaxTest, varbinary) {
  const TypePtr varbinaryArray = ARRAY(VARBINARY());

  EXPECT_EQ(arrayMax<std::string>({"red", "blue"}, varbinaryArray), "red");
  EXPECT_EQ(
      arrayMax<std::string>(
          {std::nullopt, "blue", "yellow", "orange"}, varbinaryArray),
      "yellow");

  // Empty and all-NULL arrays both produce NULL.
  EXPECT_EQ(arrayMax<std::string>({}, varbinaryArray), std::nullopt);
  EXPECT_EQ(
      arrayMax<std::string>({std::nullopt}, varbinaryArray), std::nullopt);
}

TEST_F(ArrayMaxTest, varchar) {
EXPECT_EQ(arrayMax<std::string>({"red", "blue"}), "red");
EXPECT_EQ(
Expand Down Expand Up @@ -87,6 +100,26 @@ TEST_F(ArrayMaxTest, timestamp) {
EXPECT_EQ(arrayMax<Timestamp>({ts(0), std::nullopt}), ts(0));
}

// array_max over arrays of bigint arrays.
TEST_F(ArrayMaxTest, complexTypes) {
  // Builds a one-row nested-array input and a one-row expected array from
  // their JSON forms, then checks array_max(c0) against the expectation.
  auto verifyMax = [&](const std::string& nestedJson,
                       const std::string& expectedJson) {
    const VectorPtr input = makeNestedArrayVectorFromJson<int64_t>({nestedJson});
    const VectorPtr expected = makeArrayVectorFromJson<int64_t>({expectedJson});
    auto actual = evaluate("array_max(c0)", makeRowVector({input}));
    assertEqualVectors(expected, actual);
  };

  // No nulls anywhere.
  verifyMax("[[1, 1, 1], [1, 2, 2], [1, 3, 1]]", "[1, 3, 1]");

  // Nulls nested inside the inner arrays.
  verifyMax("[[1, null], [null, 2], [null, null]]", "[1, null]");

  // Every inner array is null.
  verifyMax("[null, null]", "null");
}

template <typename Type>
class ArrayMaxIntegralTest : public ArrayMaxTest {
public:
Expand Down
43 changes: 40 additions & 3 deletions velox/functions/sparksql/tests/ArrayMinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ namespace {
// Fixture for array_min tests. Provides a helper that evaluates the function
// on a single array and returns the scalar result.
class ArrayMinTest : public SparkFunctionBaseTest {
protected:
// Wraps 'input' into a row vector holding one nullable array column and
// returns the result of evaluating "array_min(C0)" on it via evaluateOnce.
// 'type' is the type passed to makeNullableArrayVector; it defaults to
// ARRAY(CppToType<T>::create()).
template <typename T>
std::optional<T> arrayMin(const std::vector<std::optional<T>>& input) {
std::optional<T> arrayMin(
const std::vector<std::optional<T>>& input,
const TypePtr& type = ARRAY(CppToType<T>::create())) {
auto row = makeRowVector({makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{input})});
std::vector<std::vector<std::optional<T>>>{input}, type)});
return evaluateOnce<T>("array_min(C0)", row);
}
};
Expand All @@ -46,7 +48,18 @@ TEST_F(ArrayMinTest, boolean) {
EXPECT_EQ(arrayMin<bool>({std::nullopt, true, false, true}), false);
EXPECT_EQ(arrayMin<bool>({false, false, false}), false);
EXPECT_EQ(arrayMin<bool>({true, true, true}), true);
} // namespace
}

// array_min over ARRAY(VARBINARY) inputs: plain values, a leading NULL, an
// empty array and an all-NULL array.
TEST_F(ArrayMinTest, varbinary) {
  const TypePtr varbinaryArray = ARRAY(VARBINARY());

  EXPECT_EQ(arrayMin<std::string>({"red", "blue"}, varbinaryArray), "blue");
  EXPECT_EQ(
      arrayMin<std::string>(
          {std::nullopt, "blue", "yellow", "orange"}, varbinaryArray),
      "blue");

  // Empty and all-NULL arrays both produce NULL.
  EXPECT_EQ(arrayMin<std::string>({}, varbinaryArray), std::nullopt);
  EXPECT_EQ(
      arrayMin<std::string>({std::nullopt}, varbinaryArray), std::nullopt);
}

TEST_F(ArrayMinTest, varchar) {
EXPECT_EQ(arrayMin<std::string>({"red", "blue"}), "blue");
Expand Down Expand Up @@ -88,6 +101,30 @@ TEST_F(ArrayMinTest, timestamp) {
EXPECT_EQ(arrayMin<Timestamp>({ts(0), std::nullopt}), ts(0));
}

// array_min over arrays of bigint arrays.
TEST_F(ArrayMinTest, complexTypes) {
  // Builds a one-row nested-array input and a one-row expected array from
  // their JSON forms, then checks array_min(c0) against the expectation.
  auto verifyMin = [&](const std::string& nestedJson,
                       const std::string& expectedJson) {
    const VectorPtr input = makeNestedArrayVectorFromJson<int64_t>({nestedJson});
    const VectorPtr expected = makeArrayVectorFromJson<int64_t>({expectedJson});
    auto actual = evaluate("array_min(c0)", makeRowVector({input}));
    assertEqualVectors(expected, actual);
  };

  // No nulls anywhere.
  verifyMin("[[1, 1, 1], [1, 1, 2], [1, 3, 1]]", "[1, 1, 1]");

  // Nulls nested inside the inner arrays.
  verifyMin("[[1, null], [null, 2], [null, null]]", "[null, null]");

  // An empty inner array is expected to be the minimum.
  verifyMin("[[1], [null], []]", "[]");

  // Every inner array is null.
  verifyMin("[null, null]", "null");
}

template <typename Type>
class ArrayMinIntegralTest : public ArrayMinTest {
public:
Expand Down

0 comments on commit 090c150

Please sign in to comment.