Skip to content

Commit d7c5383

Browse files
yuandagitsfacebook-github-bot
authored andcommitted
Add comparison support for RowView (facebookincubator#11499)
Summary: Add comparison support for RowView. This allows comparing IPPrefix values, whose underlying type is Row<int128_t, int8_t>; until now, BETWEEN and the other comparison operations could not be used on them because RowView did not implement comparisons. RowView's comparison is extended in the same way as GenericView's: a new base class (BaseView) provides the comparison operators, and each class that derives from it supplies its own compare() implementation. For RowView, compare() walks the underlying tuple one child at a time and stops at the first child whose underlying vector returns a non-zero (or indeterminate) comparison result; if every child compares equal, the rows are equal. Differential Revision: D65700875
1 parent 57040fd commit d7c5383

File tree

2 files changed

+225
-38
lines changed

2 files changed

+225
-38
lines changed

velox/expression/ComplexViewTypes.h

Lines changed: 84 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,59 @@
3131

3232
namespace facebook::velox::exec {
3333

34+
/// Base class for views that need comparison. Default comparison is forbidden
35+
/// and this class requires specialization. For now, defaulting the
36+
/// == flag to be kNullAsValue and other comparison flag to be
37+
/// kNullAsIndeterminate. Can adjust in the future to be more configurable.
38+
template <typename T>
39+
struct BaseView {
40+
std::optional<int64_t> compare(
41+
const T& /*other*/,
42+
const CompareFlags /*flags*/) const {
43+
VELOX_UNSUPPORTED("Must provide specialization");
44+
}
45+
46+
bool operator==(const T& other) const {
47+
static constexpr auto kEqualValueAtFlags =
48+
CompareFlags::equality(CompareFlags::NullHandlingMode::kNullAsValue);
49+
return this->compareOrThrow(other, kEqualValueAtFlags) == 0;
50+
}
51+
52+
bool operator<(const T& other) const {
53+
return this->compareOrThrow(other) < 0;
54+
}
55+
56+
bool operator<=(const T& other) const {
57+
return this->compareOrThrow(other) <= 0;
58+
}
59+
60+
bool operator>(const T& other) const {
61+
return this->compareOrThrow(other) > 0;
62+
}
63+
64+
bool operator>=(const T& other) const {
65+
return this->compareOrThrow(other) >= 0;
66+
}
67+
68+
bool operator!=(const T& other) const {
69+
return this->compareOrThrow(other) != 0;
70+
}
71+
72+
private:
73+
int64_t compareOrThrow(
74+
const T& other,
75+
CompareFlags flags = CompareFlags{
76+
.nullHandlingMode =
77+
CompareFlags::NullHandlingMode::kNullAsIndeterminate}) const {
78+
auto result = static_cast<const T*>(this)->compare(other, flags);
79+
// Will throw if it encounters null elements before result is determined.
80+
VELOX_DCHECK(
81+
result.has_value(),
82+
"Compare should have thrown when null is encountered in child.");
83+
return result.value();
84+
}
85+
};
86+
3487
template <typename T>
3588
struct VectorReader;
3689

@@ -927,7 +980,7 @@ class DynamicRowView {
927980
};
928981

929982
template <bool returnsOptionalValues, typename... T>
930-
class RowView {
983+
class RowView : public BaseView<RowView<returnsOptionalValues, T...>> {
931984
using reader_t = std::tuple<std::unique_ptr<VectorReader<T>>...>;
932985
using types = std::tuple<T...>;
933986

@@ -967,11 +1020,40 @@ class RowView {
9671020
return result;
9681021
}
9691022

1023+
std::optional<int64_t> compare(const RowView& other, const CompareFlags flags)
1024+
const {
1025+
return compareImpl(other, flags);
1026+
}
1027+
9701028
private:
9711029
/// Forwards to initializeImpl with an index sequence that has one index per
/// child type in T....
void initialize() {
9721030
initializeImpl(std::index_sequence_for<T...>());
9731031
}
9741032

1033+
template <std::size_t Is = 0>
1034+
std::optional<int64_t> compareImpl(
1035+
const RowView& other,
1036+
const CompareFlags flags) const {
1037+
if constexpr (Is < sizeof...(T)) {
1038+
auto result = std::get<Is>(*childReaders_)
1039+
->baseVector()
1040+
->compare(
1041+
std::get<Is>(*other.childReaders_)->baseVector(),
1042+
offset_,
1043+
other.offset_,
1044+
flags);
1045+
if (!result.has_value()) {
1046+
return std::nullopt;
1047+
}
1048+
if (result.value() != 0) {
1049+
return result.value();
1050+
}
1051+
1052+
return compareImpl<Is + 1>(other, flags);
1053+
}
1054+
return 0;
1055+
}
1056+
9751057
using children_types = std::tuple<T...>;
9761058
template <std::size_t... Is>
9771059
void materializeImpl(materialize_t& result, std::index_sequence<Is...>)
@@ -1068,7 +1150,7 @@ struct AllGenericExceptTop<Row<T...>> {
10681150
}
10691151
};
10701152

1071-
class GenericView {
1153+
class GenericView : public BaseView<GenericView> {
10721154
public:
10731155
GenericView(
10741156
const DecodedVector& decoded,
@@ -1092,39 +1174,6 @@ class GenericView {
10921174
return decoded_.base();
10931175
}
10941176

1095-
bool operator==(const GenericView& other) const {
1096-
return decoded_.base()->equalValueAt(
1097-
other.decoded_.base(), decodedIndex(), other.decodedIndex());
1098-
}
1099-
1100-
int64_t compareOrThrow(const GenericView& other) const {
1101-
static constexpr CompareFlags kFlags = {
1102-
.nullHandlingMode =
1103-
CompareFlags::NullHandlingMode::kNullAsIndeterminate};
1104-
std::optional<int64_t> result = this->compare(other, kFlags);
1105-
// Will throw if it encounters null elements before result is determined.
1106-
VELOX_DCHECK(
1107-
result.has_value(),
1108-
"Compare should have thrown when null is encountered in child.");
1109-
return result.value();
1110-
}
1111-
1112-
bool operator<(const GenericView& other) const {
1113-
return compareOrThrow(other) < 0;
1114-
}
1115-
1116-
bool operator<=(const GenericView& other) const {
1117-
return compareOrThrow(other) <= 0;
1118-
}
1119-
1120-
bool operator>(const GenericView& other) const {
1121-
return compareOrThrow(other) > 0;
1122-
}
1123-
1124-
bool operator>=(const GenericView& other) const {
1125-
return compareOrThrow(other) >= 0;
1126-
}
1127-
11281177
/// Returns decoded_.index(index_), i.e. this view's index_ mapped through the
/// decoded vector.
vector_size_t decodedIndex() const {
11291178
return decoded_.index(index_);
11301179
}

velox/expression/tests/RowViewTest.cpp

Lines changed: 141 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@
1717
#include <gtest/gtest.h>
1818
#include <optional>
1919

20+
#include "velox/common/base/tests/GTestUtils.h"
2021
#include "velox/expression/VectorReaders.h"
2122
#include "velox/functions/Udf.h"
23+
#include "velox/functions/prestosql/Comparisons.h"
2224
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
2325

2426
namespace {
2527

2628
using namespace facebook::velox;
29+
using namespace facebook::velox::functions;
30+
using namespace facebook::velox::test;
2731

2832
DecodedVector* decode(DecodedVector& decoder, const BaseVector& vector) {
2933
SelectivityVector rows(vector.size());
@@ -145,6 +149,133 @@ class RowViewTest : public functions::test::FunctionBaseTest {
145149
}
146150
}
147151
}
152+
153+
void compareTest() {
154+
auto rowVector1 = makeRowVector(
155+
{makeNullableFlatVector<int32_t>({std::nullopt}),
156+
makeNullableFlatVector<float>({1.0})});
157+
auto rowVector2 = makeRowVector(
158+
{makeNullableFlatVector<int32_t>({std::nullopt}),
159+
makeNullableFlatVector<float>({2.0})});
160+
{
161+
DecodedVector decoded1;
162+
DecodedVector decoded2;
163+
164+
exec::VectorReader<Row<int32_t, float>> reader1(
165+
decode(decoded1, *rowVector1));
166+
exec::VectorReader<Row<int32_t, float>> reader2(
167+
decode(decoded2, *rowVector2));
168+
169+
ASSERT_TRUE(reader1.isSet(0));
170+
ASSERT_TRUE(reader2.isSet(0));
171+
auto l = read(reader1, 0);
172+
auto r = read(reader2, 0);
173+
// Default flag for all operators other than `==` is kNullAsIndeterminate
174+
VELOX_ASSERT_THROW(r < l, "Ordering nulls is not supported");
175+
VELOX_ASSERT_THROW(r <= l, "Ordering nulls is not supported");
176+
VELOX_ASSERT_THROW(r > l, "Ordering nulls is not supported");
177+
VELOX_ASSERT_THROW(r >= l, "Ordering nulls is not supported");
178+
179+
// Default flag for `==` is kNullAsValue
180+
ASSERT_FALSE(r == l);
181+
182+
// Test we can pass in a flag to change the behavior for compare
183+
ASSERT_LT(
184+
l.compare(
185+
r,
186+
CompareFlags::equality(
187+
CompareFlags::NullHandlingMode::kNullAsValue)),
188+
0);
189+
}
190+
191+
// Test indeterminate ROW<integer, float> = [null, 2.0] against
192+
// [null, 2.0] is indeterminate
193+
{
194+
auto rowVector = vectorMaker_.rowVector(
195+
{BaseVector::createNullConstant(
196+
ROW({{"a", INTEGER()}}), 1, pool_.get()),
197+
makeNullableFlatVector<float>({1.0})});
198+
199+
DecodedVector decoded1;
200+
exec::VectorReader<Row<int32_t, float>> reader1(
201+
decode(decoded1, *rowVector1));
202+
ASSERT_TRUE(reader1.isSet(0));
203+
auto l = read(reader1, 0);
204+
auto flags = CompareFlags::equality(
205+
CompareFlags::NullHandlingMode::kNullAsIndeterminate);
206+
ASSERT_EQ(l.compare(l, flags), kIndeterminate);
207+
}
208+
}
209+
210+
void e2eComparisonTest() {
211+
auto lhs = makeRowVector(
212+
{makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}),
213+
makeFlatVector<float>({1.0, 2.0, 3.0, 4.0, 6.0, 0.0})});
214+
auto rhs = makeRowVector(
215+
{makeNullableFlatVector<int32_t>({5, 4, 3, 4, 5, 6}),
216+
makeFlatVector<float>({2.0, 2.0, 3.0, 4.0, 6.0, 1.1})});
217+
218+
registerFunction<
219+
EqFunction,
220+
bool,
221+
Row<int32_t, float>,
222+
Row<int32_t, float>>({"row_eq"});
223+
auto result =
224+
evaluate<FlatVector<bool>>("row_eq(c0, c1)", makeRowVector({lhs, rhs}));
225+
assertEqualVectors(
226+
makeFlatVector<bool>({false, false, true, true, true, false}), result);
227+
228+
registerFunction<
229+
NeqFunction,
230+
bool,
231+
Row<int32_t, float>,
232+
Row<int32_t, float>>({"row_neq"});
233+
result = evaluate<FlatVector<bool>>(
234+
"row_neq(c0, c1)", makeRowVector({lhs, rhs}));
235+
assertEqualVectors(
236+
makeFlatVector<bool>({true, true, false, false, false, true}), result);
237+
238+
registerFunction<
239+
LtFunction,
240+
bool,
241+
Row<int32_t, float>,
242+
Row<int32_t, float>>({"row_lt"});
243+
result =
244+
evaluate<FlatVector<bool>>("row_lt(c0, c1)", makeRowVector({lhs, rhs}));
245+
assertEqualVectors(
246+
makeFlatVector<bool>({true, true, false, false, false, true}), result);
247+
248+
registerFunction<
249+
GtFunction,
250+
bool,
251+
Row<int32_t, float>,
252+
Row<int32_t, float>>({"row_gt"});
253+
result =
254+
evaluate<FlatVector<bool>>("row_gt(c0, c1)", makeRowVector({lhs, rhs}));
255+
assertEqualVectors(
256+
makeFlatVector<bool>({false, false, false, false, false, false}),
257+
result);
258+
259+
registerFunction<
260+
LteFunction,
261+
bool,
262+
Row<int32_t, float>,
263+
Row<int32_t, float>>({"row_lte"});
264+
result = evaluate<FlatVector<bool>>(
265+
"row_lte(c0, c1)", makeRowVector({lhs, rhs}));
266+
assertEqualVectors(
267+
makeFlatVector<bool>({true, true, true, true, true, true}), result);
268+
269+
registerFunction<
270+
GteFunction,
271+
bool,
272+
Row<int32_t, float>,
273+
Row<int32_t, float>>({"row_gte"});
274+
result = evaluate<FlatVector<bool>>(
275+
"row_gte(c0, c1)", makeRowVector({lhs, rhs}));
276+
assertEqualVectors(
277+
makeFlatVector<bool>({false, false, true, true, true, false}), result);
278+
}
148279
};
149280

150281
class NullableRowViewTest : public RowViewTest<true> {};
@@ -188,6 +319,13 @@ TEST_F(NullFreeRowViewTest, materialize) {
188319
1, "hi", {1, 2, 3}};
189320
ASSERT_EQ(reader.readNullFree(0).materialize(), expected);
190321
}
322+
// Runs the shared compareTest() checks using the NullFreeRowViewTest fixture.
TEST_F(NullFreeRowViewTest, compare) {
323+
compareTest();
324+
}
325+
326+
// Runs the shared e2eComparisonTest() checks using the NullFreeRowViewTest
// fixture.
TEST_F(NullFreeRowViewTest, e2eCompare) {
327+
e2eComparisonTest();
328+
}
191329

192330
TEST_F(NullableRowViewTest, materialize) {
193331
auto result = evaluate(
@@ -299,16 +437,16 @@ TEST_F(DynamicRowViewTest, castToDynamicRowInFunction) {
299437

300438
// Input is not struct.
301439
auto result = evaluate("struct_width(c0)", makeRowVector({flatVector}));
302-
test::assertEqualVectors(makeFlatVector<int64_t>({0, 0}), result);
440+
assertEqualVectors(makeFlatVector<int64_t>({0, 0}), result);
303441

304442
result = evaluate(
305443
"struct_width(c0)", makeRowVector({makeRowVector({flatVector})}));
306-
test::assertEqualVectors(makeFlatVector<int64_t>({1, 1}), result);
444+
assertEqualVectors(makeFlatVector<int64_t>({1, 1}), result);
307445

308446
result = evaluate(
309447
"struct_width(c0)",
310448
makeRowVector({makeRowVector({flatVector, flatVector})}));
311-
test::assertEqualVectors(makeFlatVector<int64_t>({2, 2}), result);
449+
assertEqualVectors(makeFlatVector<int64_t>({2, 2}), result);
312450
}
313451
}
314452
} // namespace

0 commit comments

Comments
 (0)