Skip to content

Commit

Permalink
Add comparison support for RowView (facebookincubator#11499)
Browse files Browse the repository at this point in the history
Summary:

Support comparison for RowView. This will allow us to compare IPPrefix which has an underlying type of Row<int128_t, int8_t>.

When doing between and other comparison operations, RowView's comparisons are not implemented. We can extend RowView's comparison similar to GenericView.

I introduce a base class which does the comparisons. The classes which implement the base class can specialize how they want to implement compare.

For RowView, I iterate through the underlying tuple 1 by 1 until we find the first match where the underlying RowVector returns a non-zero comparison.

Differential Revision: D65700875
  • Loading branch information
yuandagits authored and facebook-github-bot committed Nov 11, 2024
1 parent 57040fd commit d7c5383
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 38 deletions.
119 changes: 84 additions & 35 deletions velox/expression/ComplexViewTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,59 @@

namespace facebook::velox::exec {

/// Base class for views that need comparison. Default comparison is forbidden
/// and this class requires specialization. For now, defaulting the
/// == flag to be kNullAsValue and other comparison flag to be
/// kNullAsIndeterminate. Can adjust in the future to be more configurable.
template <typename T>
struct BaseView {
std::optional<int64_t> compare(
const T& /*other*/,
const CompareFlags /*flags*/) const {
VELOX_UNSUPPORTED("Must provide specialization");
}

bool operator==(const T& other) const {
static constexpr auto kEqualValueAtFlags =
CompareFlags::equality(CompareFlags::NullHandlingMode::kNullAsValue);
return this->compareOrThrow(other, kEqualValueAtFlags) == 0;
}

bool operator<(const T& other) const {
return this->compareOrThrow(other) < 0;
}

bool operator<=(const T& other) const {
return this->compareOrThrow(other) <= 0;
}

bool operator>(const T& other) const {
return this->compareOrThrow(other) > 0;
}

bool operator>=(const T& other) const {
return this->compareOrThrow(other) >= 0;
}

bool operator!=(const T& other) const {
return this->compareOrThrow(other) != 0;
}

private:
int64_t compareOrThrow(
const T& other,
CompareFlags flags = CompareFlags{
.nullHandlingMode =
CompareFlags::NullHandlingMode::kNullAsIndeterminate}) const {
auto result = static_cast<const T*>(this)->compare(other, flags);
// Will throw if it encounters null elements before result is determined.
VELOX_DCHECK(
result.has_value(),
"Compare should have thrown when null is encountered in child.");
return result.value();
}
};

template <typename T>
struct VectorReader;

Expand Down Expand Up @@ -927,7 +980,7 @@ class DynamicRowView {
};

template <bool returnsOptionalValues, typename... T>
class RowView {
class RowView : public BaseView<RowView<returnsOptionalValues, T...>> {
using reader_t = std::tuple<std::unique_ptr<VectorReader<T>>...>;
using types = std::tuple<T...>;

Expand Down Expand Up @@ -967,11 +1020,40 @@ class RowView {
return result;
}

std::optional<int64_t> compare(const RowView& other, const CompareFlags flags)
const {
return compareImpl(other, flags);
}

private:
void initialize() {
initializeImpl(std::index_sequence_for<T...>());
}

template <std::size_t Is = 0>
std::optional<int64_t> compareImpl(
const RowView& other,
const CompareFlags flags) const {
if constexpr (Is < sizeof...(T)) {
auto result = std::get<Is>(*childReaders_)
->baseVector()
->compare(
std::get<Is>(*other.childReaders_)->baseVector(),
offset_,
other.offset_,
flags);
if (!result.has_value()) {
return std::nullopt;
}
if (result.value() != 0) {
return result.value();
}

return compareImpl<Is + 1>(other, flags);
}
return 0;
}

using children_types = std::tuple<T...>;
template <std::size_t... Is>
void materializeImpl(materialize_t& result, std::index_sequence<Is...>)
Expand Down Expand Up @@ -1068,7 +1150,7 @@ struct AllGenericExceptTop<Row<T...>> {
}
};

class GenericView {
class GenericView : public BaseView<GenericView> {
public:
GenericView(
const DecodedVector& decoded,
Expand All @@ -1092,39 +1174,6 @@ class GenericView {
return decoded_.base();
}

bool operator==(const GenericView& other) const {
return decoded_.base()->equalValueAt(
other.decoded_.base(), decodedIndex(), other.decodedIndex());
}

int64_t compareOrThrow(const GenericView& other) const {
static constexpr CompareFlags kFlags = {
.nullHandlingMode =
CompareFlags::NullHandlingMode::kNullAsIndeterminate};
std::optional<int64_t> result = this->compare(other, kFlags);
// Will throw if it encounters null elements before result is determined.
VELOX_DCHECK(
result.has_value(),
"Compare should have thrown when null is encountered in child.");
return result.value();
}

bool operator<(const GenericView& other) const {
return compareOrThrow(other) < 0;
}

bool operator<=(const GenericView& other) const {
return compareOrThrow(other) <= 0;
}

bool operator>(const GenericView& other) const {
return compareOrThrow(other) > 0;
}

bool operator>=(const GenericView& other) const {
return compareOrThrow(other) >= 0;
}

vector_size_t decodedIndex() const {
return decoded_.index(index_);
}
Expand Down
144 changes: 141 additions & 3 deletions velox/expression/tests/RowViewTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@
#include <gtest/gtest.h>
#include <optional>

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/expression/VectorReaders.h"
#include "velox/functions/Udf.h"
#include "velox/functions/prestosql/Comparisons.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

namespace {

using namespace facebook::velox;
using namespace facebook::velox::functions;
using namespace facebook::velox::test;

DecodedVector* decode(DecodedVector& decoder, const BaseVector& vector) {
SelectivityVector rows(vector.size());
Expand Down Expand Up @@ -145,6 +149,133 @@ class RowViewTest : public functions::test::FunctionBaseTest {
}
}
}

void compareTest() {
auto rowVector1 = makeRowVector(
{makeNullableFlatVector<int32_t>({std::nullopt}),
makeNullableFlatVector<float>({1.0})});
auto rowVector2 = makeRowVector(
{makeNullableFlatVector<int32_t>({std::nullopt}),
makeNullableFlatVector<float>({2.0})});
{
DecodedVector decoded1;
DecodedVector decoded2;

exec::VectorReader<Row<int32_t, float>> reader1(
decode(decoded1, *rowVector1));
exec::VectorReader<Row<int32_t, float>> reader2(
decode(decoded2, *rowVector2));

ASSERT_TRUE(reader1.isSet(0));
ASSERT_TRUE(reader2.isSet(0));
auto l = read(reader1, 0);
auto r = read(reader2, 0);
// Default flag for all operators other than `==` is kNullAsIndeterminate
VELOX_ASSERT_THROW(r < l, "Ordering nulls is not supported");
VELOX_ASSERT_THROW(r <= l, "Ordering nulls is not supported");
VELOX_ASSERT_THROW(r > l, "Ordering nulls is not supported");
VELOX_ASSERT_THROW(r >= l, "Ordering nulls is not supported");

// Default flag for `==` is kNullAsValue
ASSERT_FALSE(r == l);

// Test we can pass in a flag to change the behavior for compare
ASSERT_LT(
l.compare(
r,
CompareFlags::equality(
CompareFlags::NullHandlingMode::kNullAsValue)),
0);
}

// Test indeterminate ROW<integer, float> = [null, 2.0] against
// [null, 2.0] is indeterminate
{
auto rowVector = vectorMaker_.rowVector(
{BaseVector::createNullConstant(
ROW({{"a", INTEGER()}}), 1, pool_.get()),
makeNullableFlatVector<float>({1.0})});

DecodedVector decoded1;
exec::VectorReader<Row<int32_t, float>> reader1(
decode(decoded1, *rowVector1));
ASSERT_TRUE(reader1.isSet(0));
auto l = read(reader1, 0);
auto flags = CompareFlags::equality(
CompareFlags::NullHandlingMode::kNullAsIndeterminate);
ASSERT_EQ(l.compare(l, flags), kIndeterminate);
}
}

void e2eComparisonTest() {
auto lhs = makeRowVector(
{makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}),
makeFlatVector<float>({1.0, 2.0, 3.0, 4.0, 6.0, 0.0})});
auto rhs = makeRowVector(
{makeNullableFlatVector<int32_t>({5, 4, 3, 4, 5, 6}),
makeFlatVector<float>({2.0, 2.0, 3.0, 4.0, 6.0, 1.1})});

registerFunction<
EqFunction,
bool,
Row<int32_t, float>,
Row<int32_t, float>>({"row_eq"});
auto result =
evaluate<FlatVector<bool>>("row_eq(c0, c1)", makeRowVector({lhs, rhs}));
assertEqualVectors(
makeFlatVector<bool>({false, false, true, true, true, false}), result);

registerFunction<
NeqFunction,
bool,
Row<int32_t, float>,
Row<int32_t, float>>({"row_neq"});
result = evaluate<FlatVector<bool>>(
"row_neq(c0, c1)", makeRowVector({lhs, rhs}));
assertEqualVectors(
makeFlatVector<bool>({true, true, false, false, false, true}), result);

registerFunction<
LtFunction,
bool,
Row<int32_t, float>,
Row<int32_t, float>>({"row_lt"});
result =
evaluate<FlatVector<bool>>("row_lt(c0, c1)", makeRowVector({lhs, rhs}));
assertEqualVectors(
makeFlatVector<bool>({true, true, false, false, false, true}), result);

registerFunction<
GtFunction,
bool,
Row<int32_t, float>,
Row<int32_t, float>>({"row_gt"});
result =
evaluate<FlatVector<bool>>("row_gt(c0, c1)", makeRowVector({lhs, rhs}));
assertEqualVectors(
makeFlatVector<bool>({false, false, false, false, false, false}),
result);

registerFunction<
LteFunction,
bool,
Row<int32_t, float>,
Row<int32_t, float>>({"row_lte"});
result = evaluate<FlatVector<bool>>(
"row_lte(c0, c1)", makeRowVector({lhs, rhs}));
assertEqualVectors(
makeFlatVector<bool>({true, true, true, true, true, true}), result);

registerFunction<
GteFunction,
bool,
Row<int32_t, float>,
Row<int32_t, float>>({"row_gte"});
result = evaluate<FlatVector<bool>>(
"row_gte(c0, c1)", makeRowVector({lhs, rhs}));
assertEqualVectors(
makeFlatVector<bool>({false, false, true, true, true, false}), result);
}
};

class NullableRowViewTest : public RowViewTest<true> {};
Expand Down Expand Up @@ -188,6 +319,13 @@ TEST_F(NullFreeRowViewTest, materialize) {
1, "hi", {1, 2, 3}};
ASSERT_EQ(reader.readNullFree(0).materialize(), expected);
}
TEST_F(NullFreeRowViewTest, compare) {
compareTest();
}

TEST_F(NullFreeRowViewTest, e2eCompare) {
e2eComparisonTest();
}

TEST_F(NullableRowViewTest, materialize) {
auto result = evaluate(
Expand Down Expand Up @@ -299,16 +437,16 @@ TEST_F(DynamicRowViewTest, castToDynamicRowInFunction) {

// Input is not struct.
auto result = evaluate("struct_width(c0)", makeRowVector({flatVector}));
test::assertEqualVectors(makeFlatVector<int64_t>({0, 0}), result);
assertEqualVectors(makeFlatVector<int64_t>({0, 0}), result);

result = evaluate(
"struct_width(c0)", makeRowVector({makeRowVector({flatVector})}));
test::assertEqualVectors(makeFlatVector<int64_t>({1, 1}), result);
assertEqualVectors(makeFlatVector<int64_t>({1, 1}), result);

result = evaluate(
"struct_width(c0)",
makeRowVector({makeRowVector({flatVector, flatVector})}));
test::assertEqualVectors(makeFlatVector<int64_t>({2, 2}), result);
assertEqualVectors(makeFlatVector<int64_t>({2, 2}), result);
}
}
} // namespace

0 comments on commit d7c5383

Please sign in to comment.