Skip to content

feat: Register Spark array_min/max functions with orderable types #12576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion velox/docs/functions/spark/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ Array Functions
SELECT array_max(array(-1, -2, NULL)); -- -1
SELECT array_max(array()); -- NULL
SELECT array_max(array(-0.0001, -0.0002, -0.0003, float('nan'))); -- NaN
SELECT array_max(array(array(1), array(NULL))); -- array(1)
SELECT array_max(array(array(1), array(2, 1), array(2))); -- array(2, 1)
SELECT array_max(array(array(1.0), array(1.0, 2.0), array(cast('NaN' as double)))); -- array(NaN)

.. spark:function:: array_min(array(E)) -> E

Expand All @@ -105,8 +108,11 @@ Array Functions
SELECT array_min(array(-1, -2, NULL)); -- -2
SELECT array_min(array(NULL, NULL)); -- NULL
SELECT array_min(array()); -- NULL
SELECT array_min(array(4.0, float('nan')]); -- 4.0
SELECT array_min(array(4.0, float('nan'))); -- 4.0
SELECT array_min(array(NULL, float('nan'))); -- NaN
SELECT array_min(array(array(1), array(NULL))); -- array(NULL)
SELECT array_min(array(array(1), array(1, 2), array(2))); -- array(1)
SELECT array_min(array(array(1.0), array(1.0, 2.0), array(cast('NaN' as double)))); -- array(1.0)

.. spark:function:: array_position(x, element) -> bigint

Expand Down
119 changes: 82 additions & 37 deletions velox/functions/sparksql/ArrayMinMaxFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,78 @@ struct ArrayMinMaxFunction {
// Results refer to strings in the first argument.
static constexpr int32_t reuse_strings_from_arg = 0;

// Selects a single element of 'array' by folding every non-null element
// through update() and writing the survivor to 'out' via assign().
//
// Returns false (NULL result) when the array is empty or contains only NULL
// elements; returns true once 'out' has been assigned.
template <typename TReturn, typename TInput>
bool call(TReturn& out, const TInput& array) {
  // An empty array yields NULL.
  if (array.size() == 0) {
    return false;
  }

  // Fast path: no element can be NULL, so every slot is read directly.
  if (!array.mayHaveNulls()) {
    auto best = *array[0];
    for (auto idx = 1; idx < array.size(); ++idx) {
      update(best, array[idx].value());
    }
    assign(out, best);
    return true;
  }

  // Slow path: advance past any leading NULL elements.
  auto cursor = array.begin();
  for (; cursor != array.end(); ++cursor) {
    if (cursor->has_value()) {
      break;
    }
  }

  // Every element was NULL, so the result is NULL as well.
  if (cursor == array.end()) {
    return false;
  }

  // 'cursor' now points to the first non-null element; seed the running
  // result with it and fold in the remaining non-null elements.
  auto best = cursor->value();
  for (++cursor; cursor != array.end(); ++cursor) {
    if (cursor->has_value()) {
      update(best, cursor->value());
    }
  }

  assign(out, best);
  return true;
}

// Overload for arrays of orderable (generic) element type. Tracks the index
// of the currently selected non-null element and lets compare() decide
// whether a later non-null element should take its place; the selected
// element is copied into 'out'.
//
// Returns false (NULL result) when the array is empty or contains only NULL
// elements; returns true otherwise.
bool call(
    out_type<Orderable<T1>>& out,
    const arg_type<Array<Orderable<T1>>>& array) {
  // An empty array yields NULL.
  if (array.size() == 0) {
    return false;
  }

  // -1 means "no non-null element seen yet".
  int selected = -1;
  for (auto idx = 0; idx < array.size(); ++idx) {
    if (!array[idx].has_value()) {
      continue;
    }
    if (selected == -1) {
      // First non-null element becomes the initial pick.
      selected = idx;
      continue;
    }
    // The pick only moves when compare() reports true for the candidate.
    if (compare(array[selected].value(), array[idx].value())) {
      selected = idx;
    }
  }

  // Only NULL elements were present, so the result is NULL.
  if (selected == -1) {
    return false;
  }

  out.copy_from(array[selected].value());
  return true;
}

private:
template <typename T>
void update(T& currentValue, const T& candidateValue) {
// NaN is greater than any non-NaN elements for double/float type.
Expand Down Expand Up @@ -66,45 +138,18 @@ struct ArrayMinMaxFunction {
out.setNoCopy(value);
}

template <typename TReturn, typename TInput>
bool call(TReturn& out, const TInput& array) {
// Result is null if array is empty.
if (array.size() == 0) {
return false;
}

if (!array.mayHaveNulls()) {
// Input array does not have nulls.
auto currentValue = *array[0];
for (auto i = 1; i < array.size(); i++) {
update(currentValue, array[i].value());
}
assign(out, currentValue);
return true;
}

// Try to find the first non-null element.
auto it = array.begin();
while (it != array.end() && !it->has_value()) {
++it;
}
// If array contains only NULL elements, return NULL.
if (it == array.end()) {
return false;
}
bool compare(
exec::GenericView currentValue,
exec::GenericView candidateValue) {
static constexpr CompareFlags kFlags = {
.nullHandlingMode = CompareFlags::NullHandlingMode::kNullAsValue};

// Now 'it' point to the first non-null element.
auto currentValue = it->value();
++it;
while (it != array.end()) {
if (it->has_value()) {
update(currentValue, it->value());
}
++it;
auto compareResult = candidateValue.compare(currentValue, kFlags).value();
if constexpr (isMax) {
return compareResult > 0;
} else {
return compareResult < 0;
}

assign(out, currentValue);
return true;
}
};

Expand Down
2 changes: 2 additions & 0 deletions velox/functions/sparksql/registration/RegisterArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,11 @@ inline void registerArrayMinMaxFunctions(const std::string& prefix) {
registerArrayMinMaxFunctions<float>(prefix);
registerArrayMinMaxFunctions<double>(prefix);
registerArrayMinMaxFunctions<bool>(prefix);
registerArrayMinMaxFunctions<Varbinary>(prefix);
registerArrayMinMaxFunctions<Varchar>(prefix);
registerArrayMinMaxFunctions<Timestamp>(prefix);
registerArrayMinMaxFunctions<Date>(prefix);
registerArrayMinMaxFunctions<Orderable<T1>>(prefix);
}

template <typename T>
Expand Down
41 changes: 39 additions & 2 deletions velox/functions/sparksql/tests/ArrayMaxTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ namespace {
class ArrayMaxTest : public SparkFunctionBaseTest {
protected:
template <typename T>
std::optional<T> arrayMax(const std::vector<std::optional<T>>& input) {
std::optional<T> arrayMax(
const std::vector<std::optional<T>>& input,
const TypePtr& type = ARRAY(CppToType<T>::create())) {
auto row = makeRowVector({makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{input})});
std::vector<std::vector<std::optional<T>>>{input}, type)});
return evaluateOnce<T>("array_max(C0)", row);
}
};
Expand All @@ -47,6 +49,17 @@ TEST_F(ArrayMaxTest, boolean) {
EXPECT_EQ(arrayMax<bool>({true, true, true}), true);
}

// array_max over VARBINARY elements: plain input, input with a leading null,
// empty input, and all-null input.
TEST_F(ArrayMaxTest, varbinary) {
  const auto binaryArrayType = ARRAY(VARBINARY());

  EXPECT_EQ(arrayMax<std::string>({"red", "blue"}, binaryArrayType), "red");
  EXPECT_EQ(
      arrayMax<std::string>(
          {std::nullopt, "blue", "yellow", "orange"}, binaryArrayType),
      "yellow");

  // Empty and all-null arrays both evaluate to NULL.
  EXPECT_EQ(arrayMax<std::string>({}, binaryArrayType), std::nullopt);
  EXPECT_EQ(
      arrayMax<std::string>({std::nullopt}, binaryArrayType), std::nullopt);
}

TEST_F(ArrayMaxTest, varchar) {
EXPECT_EQ(arrayMax<std::string>({"red", "blue"}), "red");
EXPECT_EQ(
Expand Down Expand Up @@ -87,6 +100,30 @@ TEST_F(ArrayMaxTest, timestamp) {
EXPECT_EQ(arrayMax<Timestamp>({ts(0), std::nullopt}), ts(0));
}

// array_max over arrays whose elements are themselves arrays.
TEST_F(ArrayMaxTest, complexTypes) {
  // Evaluates array_max on 'nested' as column c0 and checks the result
  // against 'want'.
  const auto verify = [&](const VectorPtr& nested, const VectorPtr& want) {
    const auto actual = evaluate("array_max(c0)", makeRowVector({nested}));
    assertEqualVectors(want, actual);
  };

  // Inner arrays of integers without nulls.
  verify(
      makeNestedArrayVectorFromJson<int64_t>(
          {"[[1, 1, 1], [1, 2, 2], [1, 3, 1]]"}),
      makeArrayVectorFromJson<int64_t>({"[1, 3, 1]"}));

  // Inner arrays of doubles containing nulls and NaN.
  verify(
      makeNestedArrayVectorFromJson<double>(
          {"[[2.0, null], [null, 2.0], [NaN, 1.0]]"}),
      makeArrayVectorFromJson<double>({"[NaN, 1.0]"}));

  // Inner arrays of differing lengths with null elements.
  verify(
      makeNestedArrayVectorFromJson<int64_t>({"[[1, null], [null, 2], [1]]"}),
      makeArrayVectorFromJson<int64_t>({"[1, null]"}));

  // Outer array holding only null inner arrays.
  verify(
      makeNestedArrayVectorFromJson<int64_t>({"[null, null]"}),
      makeArrayVectorFromJson<int64_t>({"null"}));
}

template <typename Type>
class ArrayMaxIntegralTest : public ArrayMaxTest {
public:
Expand Down
47 changes: 44 additions & 3 deletions velox/functions/sparksql/tests/ArrayMinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ namespace {
class ArrayMinTest : public SparkFunctionBaseTest {
protected:
template <typename T>
std::optional<T> arrayMin(const std::vector<std::optional<T>>& input) {
std::optional<T> arrayMin(
const std::vector<std::optional<T>>& input,
const TypePtr& type = ARRAY(CppToType<T>::create())) {
auto row = makeRowVector({makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{input})});
std::vector<std::vector<std::optional<T>>>{input}, type)});
return evaluateOnce<T>("array_min(C0)", row);
}
};
Expand All @@ -46,7 +48,18 @@ TEST_F(ArrayMinTest, boolean) {
EXPECT_EQ(arrayMin<bool>({std::nullopt, true, false, true}), false);
EXPECT_EQ(arrayMin<bool>({false, false, false}), false);
EXPECT_EQ(arrayMin<bool>({true, true, true}), true);
} // namespace
}

// array_min over VARBINARY elements: plain input, input with a leading null,
// empty input, and all-null input.
TEST_F(ArrayMinTest, varbinary) {
  const auto binaryArrayType = ARRAY(VARBINARY());

  EXPECT_EQ(arrayMin<std::string>({"red", "blue"}, binaryArrayType), "blue");
  EXPECT_EQ(
      arrayMin<std::string>(
          {std::nullopt, "blue", "yellow", "orange"}, binaryArrayType),
      "blue");

  // Empty and all-null arrays both evaluate to NULL.
  EXPECT_EQ(arrayMin<std::string>({}, binaryArrayType), std::nullopt);
  EXPECT_EQ(
      arrayMin<std::string>({std::nullopt}, binaryArrayType), std::nullopt);
}

TEST_F(ArrayMinTest, varchar) {
EXPECT_EQ(arrayMin<std::string>({"red", "blue"}), "blue");
Expand Down Expand Up @@ -88,6 +101,34 @@ TEST_F(ArrayMinTest, timestamp) {
EXPECT_EQ(arrayMin<Timestamp>({ts(0), std::nullopt}), ts(0));
}

// array_min over arrays whose elements are themselves arrays.
TEST_F(ArrayMinTest, complexTypes) {
  // Evaluates array_min on 'nested' as column c0 and checks the result
  // against 'want'.
  const auto verify = [&](const VectorPtr& nested, const VectorPtr& want) {
    const auto actual = evaluate("array_min(c0)", makeRowVector({nested}));
    assertEqualVectors(want, actual);
  };

  // Inner arrays of integers without nulls.
  verify(
      makeNestedArrayVectorFromJson<int64_t>(
          {"[[1, 1, 1], [1, 1, 2], [1, 3, 1]]"}),
      makeArrayVectorFromJson<int64_t>({"[1, 1, 1]"}));

  // Inner arrays containing null elements.
  verify(
      makeNestedArrayVectorFromJson<int64_t>(
          {"[[1, null], [null, 2], [null, null]]"}),
      makeArrayVectorFromJson<int64_t>({"[null, null]"}));

  // Inner arrays sharing a prefix but differing in length.
  verify(
      makeNestedArrayVectorFromJson<int64_t>({"[[1, null], [1, 2], [1]]"}),
      makeArrayVectorFromJson<int64_t>({"[1]"}));

  // An empty inner array among the candidates.
  verify(
      makeNestedArrayVectorFromJson<int64_t>({"[[1], [null], []]"}),
      makeArrayVectorFromJson<int64_t>({"[]"}));

  // Outer array holding only null inner arrays.
  verify(
      makeNestedArrayVectorFromJson<int64_t>({"[null, null]"}),
      makeArrayVectorFromJson<int64_t>({"null"}));
}

template <typename Type>
class ArrayMinIntegralTest : public ArrayMinTest {
public:
Expand Down
Loading