|
17 | 17 | #include <gtest/gtest.h>
|
18 | 18 | #include <optional>
|
19 | 19 |
|
| 20 | +#include "velox/common/base/tests/GTestUtils.h" |
20 | 21 | #include "velox/expression/VectorReaders.h"
|
21 | 22 | #include "velox/functions/Udf.h"
|
| 23 | +#include "velox/functions/prestosql/Comparisons.h" |
22 | 24 | #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
|
23 | 25 |
|
24 | 26 | namespace {
|
25 | 27 |
|
26 | 28 | using namespace facebook::velox;
|
| 29 | +using namespace facebook::velox::functions; |
| 30 | +using namespace facebook::velox::test; |
27 | 31 |
|
28 | 32 | DecodedVector* decode(DecodedVector& decoder, const BaseVector& vector) {
|
29 | 33 | SelectivityVector rows(vector.size());
|
@@ -145,6 +149,133 @@ class RowViewTest : public functions::test::FunctionBaseTest {
|
145 | 149 | }
|
146 | 150 | }
|
147 | 151 | }
|
| 152 | + |
| 153 | + void compareTest() { |
| 154 | + auto rowVector1 = makeRowVector( |
| 155 | + {makeNullableFlatVector<int32_t>({std::nullopt}), |
| 156 | + makeNullableFlatVector<float>({1.0})}); |
| 157 | + auto rowVector2 = makeRowVector( |
| 158 | + {makeNullableFlatVector<int32_t>({std::nullopt}), |
| 159 | + makeNullableFlatVector<float>({2.0})}); |
| 160 | + { |
| 161 | + DecodedVector decoded1; |
| 162 | + DecodedVector decoded2; |
| 163 | + |
| 164 | + exec::VectorReader<Row<int32_t, float>> reader1( |
| 165 | + decode(decoded1, *rowVector1)); |
| 166 | + exec::VectorReader<Row<int32_t, float>> reader2( |
| 167 | + decode(decoded2, *rowVector2)); |
| 168 | + |
| 169 | + ASSERT_TRUE(reader1.isSet(0)); |
| 170 | + ASSERT_TRUE(reader2.isSet(0)); |
| 171 | + auto l = read(reader1, 0); |
| 172 | + auto r = read(reader2, 0); |
| 173 | + // Default flag for all operators other than `==` is kNullAsIndeterminate |
| 174 | + VELOX_ASSERT_THROW(r < l, "Ordering nulls is not supported"); |
| 175 | + VELOX_ASSERT_THROW(r <= l, "Ordering nulls is not supported"); |
| 176 | + VELOX_ASSERT_THROW(r > l, "Ordering nulls is not supported"); |
| 177 | + VELOX_ASSERT_THROW(r >= l, "Ordering nulls is not supported"); |
| 178 | + |
| 179 | + // Default flag for `==` is kNullAsValue |
| 180 | + ASSERT_FALSE(r == l); |
| 181 | + |
| 182 | + // Test we can pass in a flag to change the behavior for compare |
| 183 | + ASSERT_LT( |
| 184 | + l.compare( |
| 185 | + r, |
| 186 | + CompareFlags::equality( |
| 187 | + CompareFlags::NullHandlingMode::kNullAsValue)), |
| 188 | + 0); |
| 189 | + } |
| 190 | + |
| 191 | + // Test indeterminate ROW<integer, float> = [null, 2.0] against |
| 192 | + // [null, 2.0] is indeterminate |
| 193 | + { |
| 194 | + auto rowVector = vectorMaker_.rowVector( |
| 195 | + {BaseVector::createNullConstant( |
| 196 | + ROW({{"a", INTEGER()}}), 1, pool_.get()), |
| 197 | + makeNullableFlatVector<float>({1.0})}); |
| 198 | + |
| 199 | + DecodedVector decoded1; |
| 200 | + exec::VectorReader<Row<int32_t, float>> reader1( |
| 201 | + decode(decoded1, *rowVector1)); |
| 202 | + ASSERT_TRUE(reader1.isSet(0)); |
| 203 | + auto l = read(reader1, 0); |
| 204 | + auto flags = CompareFlags::equality( |
| 205 | + CompareFlags::NullHandlingMode::kNullAsIndeterminate); |
| 206 | + ASSERT_EQ(l.compare(l, flags), kIndeterminate); |
| 207 | + } |
| 208 | + } |
| 209 | + |
| 210 | + void e2eComparisonTest() { |
| 211 | + auto lhs = makeRowVector( |
| 212 | + {makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}), |
| 213 | + makeFlatVector<float>({1.0, 2.0, 3.0, 4.0, 6.0, 0.0})}); |
| 214 | + auto rhs = makeRowVector( |
| 215 | + {makeNullableFlatVector<int32_t>({5, 4, 3, 4, 5, 6}), |
| 216 | + makeFlatVector<float>({2.0, 2.0, 3.0, 4.0, 6.0, 1.1})}); |
| 217 | + |
| 218 | + registerFunction< |
| 219 | + EqFunction, |
| 220 | + bool, |
| 221 | + Row<int32_t, float>, |
| 222 | + Row<int32_t, float>>({"row_eq"}); |
| 223 | + auto result = |
| 224 | + evaluate<FlatVector<bool>>("row_eq(c0, c1)", makeRowVector({lhs, rhs})); |
| 225 | + assertEqualVectors( |
| 226 | + makeFlatVector<bool>({false, false, true, true, true, false}), result); |
| 227 | + |
| 228 | + registerFunction< |
| 229 | + NeqFunction, |
| 230 | + bool, |
| 231 | + Row<int32_t, float>, |
| 232 | + Row<int32_t, float>>({"row_neq"}); |
| 233 | + result = evaluate<FlatVector<bool>>( |
| 234 | + "row_neq(c0, c1)", makeRowVector({lhs, rhs})); |
| 235 | + assertEqualVectors( |
| 236 | + makeFlatVector<bool>({true, true, false, false, false, true}), result); |
| 237 | + |
| 238 | + registerFunction< |
| 239 | + LtFunction, |
| 240 | + bool, |
| 241 | + Row<int32_t, float>, |
| 242 | + Row<int32_t, float>>({"row_lt"}); |
| 243 | + result = |
| 244 | + evaluate<FlatVector<bool>>("row_lt(c0, c1)", makeRowVector({lhs, rhs})); |
| 245 | + assertEqualVectors( |
| 246 | + makeFlatVector<bool>({true, true, false, false, false, true}), result); |
| 247 | + |
| 248 | + registerFunction< |
| 249 | + GtFunction, |
| 250 | + bool, |
| 251 | + Row<int32_t, float>, |
| 252 | + Row<int32_t, float>>({"row_gt"}); |
| 253 | + result = |
| 254 | + evaluate<FlatVector<bool>>("row_gt(c0, c1)", makeRowVector({lhs, rhs})); |
| 255 | + assertEqualVectors( |
| 256 | + makeFlatVector<bool>({false, false, false, false, false, false}), |
| 257 | + result); |
| 258 | + |
| 259 | + registerFunction< |
| 260 | + LteFunction, |
| 261 | + bool, |
| 262 | + Row<int32_t, float>, |
| 263 | + Row<int32_t, float>>({"row_lte"}); |
| 264 | + result = evaluate<FlatVector<bool>>( |
| 265 | + "row_lte(c0, c1)", makeRowVector({lhs, rhs})); |
| 266 | + assertEqualVectors( |
| 267 | + makeFlatVector<bool>({true, true, true, true, true, true}), result); |
| 268 | + |
| 269 | + registerFunction< |
| 270 | + GteFunction, |
| 271 | + bool, |
| 272 | + Row<int32_t, float>, |
| 273 | + Row<int32_t, float>>({"row_gte"}); |
| 274 | + result = evaluate<FlatVector<bool>>( |
| 275 | + "row_gte(c0, c1)", makeRowVector({lhs, rhs})); |
| 276 | + assertEqualVectors( |
| 277 | + makeFlatVector<bool>({false, false, true, true, true, false}), result); |
| 278 | + } |
148 | 279 | };
|
149 | 280 |
|
150 | 281 | class NullableRowViewTest : public RowViewTest<true> {};
|
@@ -188,6 +319,13 @@ TEST_F(NullFreeRowViewTest, materialize) {
|
188 | 319 | 1, "hi", {1, 2, 3}};
|
189 | 320 | ASSERT_EQ(reader.readNullFree(0).materialize(), expected);
|
190 | 321 | }
|
| 322 | +TEST_F(NullFreeRowViewTest, compare) { |
| 323 | + compareTest(); |
| 324 | +} |
| 325 | + |
| 326 | +TEST_F(NullFreeRowViewTest, e2eCompare) { |
| 327 | + e2eComparisonTest(); |
| 328 | +} |
191 | 329 |
|
192 | 330 | TEST_F(NullableRowViewTest, materialize) {
|
193 | 331 | auto result = evaluate(
|
@@ -299,16 +437,16 @@ TEST_F(DynamicRowViewTest, castToDynamicRowInFunction) {
|
299 | 437 |
|
300 | 438 | // Input is not struct.
|
301 | 439 | auto result = evaluate("struct_width(c0)", makeRowVector({flatVector}));
|
302 |
| - test::assertEqualVectors(makeFlatVector<int64_t>({0, 0}), result); |
| 440 | + assertEqualVectors(makeFlatVector<int64_t>({0, 0}), result); |
303 | 441 |
|
304 | 442 | result = evaluate(
|
305 | 443 | "struct_width(c0)", makeRowVector({makeRowVector({flatVector})}));
|
306 |
| - test::assertEqualVectors(makeFlatVector<int64_t>({1, 1}), result); |
| 444 | + assertEqualVectors(makeFlatVector<int64_t>({1, 1}), result); |
307 | 445 |
|
308 | 446 | result = evaluate(
|
309 | 447 | "struct_width(c0)",
|
310 | 448 | makeRowVector({makeRowVector({flatVector, flatVector})}));
|
311 |
| - test::assertEqualVectors(makeFlatVector<int64_t>({2, 2}), result); |
| 449 | + assertEqualVectors(makeFlatVector<int64_t>({2, 2}), result); |
312 | 450 | }
|
313 | 451 | }
|
314 | 452 | } // namespace
|
0 commit comments