diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 1103d0246452..f614220fc89c 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -47,14 +47,14 @@ jobs: INSTALL_PREFIX: ${{ github.workspace }}/dependencies steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox persist-credentials: false - name: Restore Dependencies - uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 id: restore-deps with: path: ${{ env.INSTALL_PREFIX }} @@ -84,7 +84,7 @@ jobs: - name: Save Dependencies if: ${{ steps.restore-deps.outputs.cache-hit != 'true' }} - uses: actions/cache/save@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + uses: actions/cache/save@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: ${{ env.INSTALL_PREFIX }} key: dependencies-benchmark-${{ hashFiles('velox/scripts/setup-ubuntu.sh') }} @@ -151,12 +151,12 @@ jobs: merge-multiple: true path: /tmp/artifacts/ - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox persist-credentials: false - - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: '3.10' cache: pip @@ -187,7 +187,7 @@ jobs: run: echo "failed=true" >> $GITHUB_OUTPUT - name: Create a GitHub Status on the contender commit (whether the upload was successful) - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 if: ${{ !cancelled() && steps.extract.conclusion != 'failure' }} with: script: | diff --git 
a/.github/workflows/breeze.yml b/.github/workflows/breeze.yml index 66eaaa480c18..d6df76dc41b8 100644 --- a/.github/workflows/breeze.yml +++ b/.github/workflows/breeze.yml @@ -54,13 +54,13 @@ jobs: working-directory: velox steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5.0.0 with: path: velox persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0 + uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6.8.0 - name: Install Dependencies run: | @@ -95,13 +95,13 @@ jobs: working-directory: velox steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5.0.0 with: path: velox persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0 + uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6.8.0 - name: Install Dependencies run: | diff --git a/.github/workflows/build-metrics.yml b/.github/workflows/build-metrics.yml index 2bf62b4352e2..4f45d1671d26 100644 --- a/.github/workflows/build-metrics.yml +++ b/.github/workflows/build-metrics.yml @@ -49,7 +49,7 @@ jobs: run: shell: bash steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5.0.0 with: ref: ${{ inputs.ref || github.sha }} persist-credentials: false @@ -145,7 +145,7 @@ jobs: needs: metrics steps: - name: Checkout - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: fetch-depth: 0 persist-credentials: true diff --git a/.github/workflows/build_pyvelox.yml b/.github/workflows/build_pyvelox.yml index 168e017e8022..5a397f431580 100644 --- a/.github/workflows/build_pyvelox.yml +++ b/.github/workflows/build_pyvelox.yml @@ -50,14 +50,14 @@ jobs: matrix: os: [8-core-ubuntu, macos-13, macos-14] steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: 
actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: ref: ${{ inputs.ref || github.ref }} fetch-depth: 0 persist-credentials: false - - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: '3.10' @@ -162,12 +162,12 @@ jobs: - run: ls wheelhouse - - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: '3.10' - name: Publish a Python distribution to PyPI - uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 with: password: ${{ secrets.PYPI_API_TOKEN }} packages_dir: wheelhouse diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index b01751579c06..dae36dbcba13 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -55,16 +55,27 @@ jobs: steps: - name: Free Disk Space run: | - # 15G - sudo rm -rf /usr/local/lib/android || : - # 5.3GB - sudo rm -rf /opt/hostedtoolcache/CodeQL || : + # Re-used from free-disk-space github action. + getAvailableSpace() { echo $(df -a $1 | awk 'NR > 1 {avail+=$4} END {print avail}'); } + # Show before + echo "Original available disk space: " $(getAvailableSpace) + # Remove DotNet. 
+ sudo rm -rf /usr/share/dotnet || true + # Remove android + sudo rm -rf /usr/local/lib/android || true + # Remove CodeQL + sudo rm -rf /opt/hostedtoolcache/CodeQL || true + # Remove Haskell + sudo rm -rf /opt/ghc || true + sudo rm -rf /usr/local/.ghcup || true + # Show after + echo "New available disk space: " $(getAvailableSpace) - name: Set up Docker Buildx uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1 - name: Login to GitHub Container Registry - uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: registry: ghcr.io username: ${{ github.actor }} @@ -130,7 +141,7 @@ jobs: packages: write steps: - name: Login to GitHub Container Registry - uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: registry: ghcr.io username: ${{ github.actor }} @@ -168,7 +179,7 @@ jobs: target: [java] steps: - name: Login to GitHub Container Registry - uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: registry: ghcr.io username: ${{ github.actor }} diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 523fda54dcf7..2a749fbd2a76 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -51,7 +51,7 @@ jobs: key: ccache-docs-8-core-ubuntu - name: Checkout - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: fetch-depth: 0 persist-credentials: true diff --git a/.github/workflows/linux-build-base.yml b/.github/workflows/linux-build-base.yml index 2f08b94319c5..48563fb1fba0 100644 --- a/.github/workflows/linux-build-base.yml +++ b/.github/workflows/linux-build-base.yml @@ -36,13 +36,12 @@ jobs: 
env: CCACHE_DIR: ${{ github.workspace }}/ccache VELOX_DEPENDENCY_SOURCE: SYSTEM - GTest_SOURCE: BUNDLED cudf_SOURCE: BUNDLED CUDA_VERSION: '12.8' faiss_SOURCE: BUNDLED USE_CLANG: "${{ inputs.use-clang && 'true' || 'false' }}" steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5.0.0 with: fetch-depth: 2 persist-credentials: false @@ -109,10 +108,10 @@ jobs: else # cuDF (unsupported for Clang) and Faiss (link issue when using Clang) # are excluded for Clang compilation and need to be added back when using GCC. - EXTRA_CMAKE_FLAGS+="-DVELOX_ENABLE_CUDF=ON" - EXTRA_CMAKE_FLAGS+="-DVELOX_ENABLE_FAISS=ON" + EXTRA_CMAKE_FLAGS+=("-DVELOX_ENABLE_CUDF=ON") + EXTRA_CMAKE_FLAGS+=("-DVELOX_ENABLE_FAISS=ON") # Investigate issues with remote function service: Issue #13897 - EXTRA_CMAKE_FLAGS+="-DVELOX_ENABLE_REMOTE_FUNCTIONS=ON" + EXTRA_CMAKE_FLAGS+=("-DVELOX_ENABLE_REMOTE_FUNCTIONS=ON") fi make release EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS[*]}" @@ -134,6 +133,8 @@ jobs: source "/opt/miniforge/etc/profile.d/conda.sh" conda activate adapters fi + # Needed for HADOOP 3.3.6 minicluster. Can remove after updating to 3.4.2. 
+ wget https://repo1.maven.org/maven2/org/mockito/mockito-core/2.23.4/mockito-core-2.23.4.jar -O /usr/local/hadoop/share/hadoop/mapreduce/mockito-core-2.23.4.jar export CLASSPATH=`/usr/local/hadoop/bin/hdfs classpath --glob` ctest -j 8 --label-exclude cuda_driver --output-on-failure --no-tests=error @@ -163,7 +164,7 @@ jobs: run: | mkdir -p "$CCACHE_DIR" - - uses: actions/checkout@v4 + - uses: actions/checkout@v5.0.0 with: path: velox persist-credentials: false @@ -206,3 +207,63 @@ jobs: - name: Run Tests run: | cd _build/debug && ctest -j 8 --output-on-failure --no-tests=error + + fedora-debug: + runs-on: 8-core-ubuntu-22.04 + container: ghcr.io/facebookincubator/velox-dev:fedora + # prevent errors when forks ff their main branch + if: ${{ github.repository == 'facebookincubator/velox' }} + name: Fedora debug + env: + CCACHE_DIR: ${{ github.workspace }}/ccache + defaults: + run: + shell: bash + working-directory: velox + steps: + - name: Get Ccache Stash + uses: apache/infrastructure-actions/stash/restore@3354c1565d4b0e335b78a76aedd82153a9e144d4 + with: + path: ${{ env.CCACHE_DIR }} + key: ccache-fedora-debug-default-gcc + + - name: Ensure Stash Dirs Exists + working-directory: ${{ github.workspace }} + run: | + mkdir -p "$CCACHE_DIR" + + - uses: actions/checkout@v4 + with: + path: velox + persist-credentials: false + + - name: Clear CCache Statistics + run: | + ccache -sz + + - name: Make Debug Build + env: + VELOX_DEPENDENCY_SOURCE: SYSTEM + faiss_SOURCE: BUNDLED + fmt_SOURCE: BUNDLED + simdjson_SOURCE: BUNDLED + gRPC_SOURCE: SYSTEM + MAKEFLAGS: NUM_THREADS=4 MAX_HIGH_MEM_JOBS=4 MAX_LINK_JOBS=3 + EXTRA_CMAKE_FLAGS: >- + -DVELOX_ENABLE_PARQUET=ON + -DARROW_THRIFT_USE_SHARED=ON + -DVELOX_ENABLE_EXAMPLES=ON + run: | + uv tool install --force cmake@3.31.1 + dnf install -y -q --setopt=install_weak_deps=False grpc-devel grpc-plugins + export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DVELOX_ENABLE_FAISS=ON" + make debug + + - name: CCache after + run: | + ccache -vs + + 
- uses: apache/infrastructure-actions/stash/save@3354c1565d4b0e335b78a76aedd82153a9e144d4 + with: + path: ${{ env.CCACHE_DIR }} + key: ccache-fedora-debug-default-gcc diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index f2db45237cbd..f195c2321e79 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -22,6 +22,9 @@ on: - CMakeLists.txt - CMake/** - scripts/setup-macos.sh + - scripts/setup-common.sh + - scripts/setup-versions.sh + - scripts/setup-helper-functions.sh - .github/workflows/macos.yml pull_request: @@ -31,6 +34,9 @@ on: - CMakeLists.txt - CMake/** - scripts/setup-macos.sh + - scripts/setup-common.sh + - scripts/setup-versions.sh + - scripts/setup-helper-functions.sh - .github/workflows/macos.yml permissions: @@ -58,7 +64,7 @@ jobs: INSTALL_PREFIX: /tmp/deps-install steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5.0.0 with: persist-credentials: false diff --git a/.github/workflows/preliminary_checks.yml b/.github/workflows/preliminary_checks.yml index fe3ce5af7108..9be23e9c0a5d 100644 --- a/.github/workflows/preliminary_checks.yml +++ b/.github/workflows/preliminary_checks.yml @@ -32,10 +32,10 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: persist-credentials: false - - uses: actions/setup-python@3542bca2639a428e1796aaa6a2ffef0c0f575566 # v3.1.4 + - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 title-check: @@ -49,7 +49,7 @@ jobs: import re import os title = os.environ["title"] - title_re = r"^(feat|fix|build|test|docs|refactor|misc)(\(.+\))?!?: ([A-Z].+)[^.]$" + title_re = r"^(feat|fix|perf|build|test|docs|refactor|misc)(\(.+\))?!?: ([A-Z].+)[^.]$" match = re.search(title_re, title) if match is None: diff --git 
a/.github/workflows/scheduled.yml b/.github/workflows/scheduled.yml index 17eb81ca190e..36af1082c1bf 100644 --- a/.github/workflows/scheduled.yml +++ b/.github/workflows/scheduled.yml @@ -172,7 +172,7 @@ jobs: - name: Checkout Main if: ${{ github.event_name != 'schedule' && steps.get-sig.outputs.stash-hit != 'true' }} - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: ref: ${{ steps.get-merge-base.outputs.head_main || 'main' }} path: velox_main @@ -204,7 +204,7 @@ jobs: key: function-signatures-${{ steps.get-merge-base.outputs.head_main }} - name: Checkout Contender - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -468,7 +468,7 @@ jobs: name: presto - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -728,7 +728,7 @@ jobs: name: join - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -829,7 +829,7 @@ jobs: name: row_number - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -891,7 +891,7 @@ jobs: name: topn_row_number - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -1057,7 +1057,7 @@ jobs: name: aggregation - name: Checkout Repo - uses: 
actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -1121,7 +1121,7 @@ jobs: name: presto - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -1210,7 +1210,7 @@ jobs: name: aggregation - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -1308,7 +1308,7 @@ jobs: name: window - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} @@ -1372,7 +1372,7 @@ jobs: name: writer - name: Checkout Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: path: velox ref: ${{ inputs.ref }} diff --git a/CMake/FindSnappy.cmake b/CMake/FindSnappy.cmake deleted file mode 100644 index 2d65b3d17666..000000000000 --- a/CMake/FindSnappy.cmake +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. 
-# - Try to find snappy -# Once done, this will define -# -# SNAPPY_FOUND - system has Glog -# SNAPPY_INCLUDE_DIRS - deprecated -# SNAPPY_LIBRARIES - deprecated -# Snappy::snappy will be defined based on CMAKE_FIND_LIBRARY_SUFFIXES priority - -include(FindPackageHandleStandardArgs) -include(SelectLibraryConfigurations) - -find_library(SNAPPY_LIBRARY_RELEASE snappy PATHS $SNAPPY_LIBRARYDIR}) -find_library(SNAPPY_LIBRARY_DEBUG snappyd PATHS ${SNAPPY_LIBRARYDIR}) - -find_path(SNAPPY_INCLUDE_DIR snappy.h PATHS ${SNAPPY_INCLUDEDIR}) - -select_library_configurations(SNAPPY) - -find_package_handle_standard_args(Snappy DEFAULT_MSG SNAPPY_LIBRARY SNAPPY_INCLUDE_DIR) - -mark_as_advanced(SNAPPY_LIBRARY SNAPPY_INCLUDE_DIR) - -get_filename_component(libsnappy_ext ${SNAPPY_LIBRARY} EXT) -if(libsnappy_ext STREQUAL ".a") - set(libsnappy_type STATIC) -else() - set(libsnappy_type SHARED) -endif() - -if(NOT TARGET Snappy::snappy) - add_library(Snappy::snappy ${libsnappy_type} IMPORTED) - set_target_properties( - Snappy::snappy - PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${SNAPPY_INCLUDE_DIR}" - ) - set_target_properties( - Snappy::snappy - PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${SNAPPY_LIBRARIES}" - ) -endif() diff --git a/CMake/VeloxUtils.cmake b/CMake/VeloxUtils.cmake index 8d80d374579a..efd964360e6c 100644 --- a/CMake/VeloxUtils.cmake +++ b/CMake/VeloxUtils.cmake @@ -92,7 +92,9 @@ function(velox_add_library TARGET) if(TARGET velox) # Target already exists, append sources to it. 
target_sources(velox PRIVATE ${ARGN}) - install(TARGETS velox LIBRARY DESTINATION pyvelox COMPONENT pyvelox_libraries) + if(VELOX_BUILD_PYTHON_PACKAGE) + install(TARGETS velox LIBRARY DESTINATION pyvelox COMPONENT pyvelox_libraries) + endif() else() set(_type STATIC) if(VELOX_BUILD_SHARED) diff --git a/CMake/resolve_dependency_modules/README.md b/CMake/resolve_dependency_modules/README.md index 84c07705a3d4..86da7262a8d3 100644 --- a/CMake/resolve_dependency_modules/README.md +++ b/CMake/resolve_dependency_modules/README.md @@ -30,7 +30,7 @@ by Velox. See details on bundling below. | xsimd | 10.0.0 | Yes | | re2 | 2024-07-02 | Yes | | fmt | 10.1.1 | Yes | -| simdjson | 3.9.3 | Yes | +| simdjson | 3.13.0 | Yes | | faiss | 1.11.0 | Yes | | folly | v2025.04.28.00 | Yes | | fizz | v2025.04.28.00 | No | diff --git a/CMake/resolve_dependency_modules/simdjson.cmake b/CMake/resolve_dependency_modules/simdjson.cmake index 962a76fadf15..0de2ea8c9e4c 100644 --- a/CMake/resolve_dependency_modules/simdjson.cmake +++ b/CMake/resolve_dependency_modules/simdjson.cmake @@ -13,10 +13,10 @@ # limitations under the License. 
include_guard(GLOBAL) -set(VELOX_SIMDJSON_VERSION 3.9.3) +set(VELOX_SIMDJSON_VERSION 3.13.0) set( VELOX_SIMDJSON_BUILD_SHA256_CHECKSUM - 2e3d10abcde543d3dd8eba9297522cafdcebdd1db4f51b28f3bc95bf1d6ad23c + 07a1bb3587aac18fd6a10a83fe4ab09f1100ab39f0cb73baea1317826b9f9e0d ) set( VELOX_SIMDJSON_SOURCE_URL diff --git a/CMakeLists.txt b/CMakeLists.txt index 2a5aac81cc22..8f1fe38f6796 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -445,7 +445,8 @@ if(ENABLE_ALL_WARNINGS) endif() if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "14.0.0") string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-error=template-id-cdtor") - string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-error=overloaded-virtual") + string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-overloaded-virtual") + string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-error=tautological-compare") endif() endif() @@ -544,9 +545,7 @@ if(NOT TARGET gflags::gflags) # target even when velox is built as a subproject which uses # `find_package(gflags)` which does not create a globally imported target that # we can ALIAS. - add_library(gflags_gflags INTERFACE) - target_link_libraries(gflags_gflags INTERFACE gflags) - add_library(gflags::gflags ALIAS gflags_gflags) + add_library(gflags::gflags ALIAS gflags) endif() if(${gflags_SOURCE} STREQUAL "BUNDLED") @@ -600,7 +599,7 @@ if(${VELOX_BUILD_MINIMAL_WITH_DWIO} OR ${VELOX_ENABLE_HIVE_CONNECTOR} OR VELOX_E endif() velox_set_source(simdjson) -velox_resolve_dependency(simdjson 3.9.3) +velox_resolve_dependency(simdjson 3.13.0) velox_set_source(folly) velox_resolve_dependency(folly) @@ -704,10 +703,6 @@ endif() include_directories(.) -# TODO: Include all other installation files. For now just making sure this -# generates an installable makefile. 
-install(FILES velox/type/Type.h DESTINATION "include/velox") - # Adding this down here prevents warnings in dependencies from stopping the # build if("${TREAT_WARNINGS_AS_ERRORS}") diff --git a/CODING_STYLE.md b/CODING_STYLE.md index 43947192cd74..97a60199941e 100644 --- a/CODING_STYLE.md +++ b/CODING_STYLE.md @@ -246,16 +246,17 @@ About comment style: * As a general rule, do not use string literals without declaring a named constant for them. * The best way to make a constant string literal is to use constexpr - `std::string_view`/`folly::StringPiece` + `std::string_view` * **NEVER** use `std::string` - this makes your code more prone to SIOF bugs. * Avoid `const char* const` and `const char*` - these are less efficient to convert to `std::string` later on in your program if you ever need to - because `std::string_view`/ `folly::StringPiece` knows its size and can use - a more efficient constructor. `std::string_view`/ `folly::StringPiece` also - has richer interfaces and often works as a drop-in replacement to - `std::string`. + because `std::string_view` knows its size and can use a more efficient + constructor. `std::string_view` also has richer interfaces and often + works as a drop-in replacement to `std::string`. * Need compile-time string concatenation? You can use `folly::FixedString` for that. + * Do not use `folly::StringPiece` in new code, use `std::string_view` + instead. ## Macros diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a76dd9e334cc..8db9559bddd3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -137,6 +137,7 @@ where: * *Type* can be any of the following keywords: * **feat** when new features are being added. * **fix** for bug fixes. + * **perf** for performance improvements. * **build** for build or CI-related improvements. * **test** for adding tests (only). * **docs** for enhancements to documentation (only). 
diff --git a/README.md b/README.md index e701cada2d7e..3fb83181fd9c 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Using the default install location `/usr/local` on macOS is discouraged since th location is used by certain Homebrew versions. Manually add the `INSTALL_PREFIX` value in the IDE or bash environment, -say `export INSTALL_PREFIX=/Users/$USERNAME/velox/deps-install` to `~/.zshrc` so that +say `export INSTALL_PREFIX=/Users/$USER/velox/deps-install` to `~/.zshrc` so that subsequent Velox builds can use the installed packages. *You can reuse `DEPENDENCY_INSTALL` and `INSTALL_PREFIX` for Velox clients such as Prestissimo diff --git a/scripts/ci/bm-report/report.qmd b/scripts/ci/bm-report/report.qmd index a91a116a6b93..b9ad14c79a44 100644 --- a/scripts/ci/bm-report/report.qmd +++ b/scripts/ci/bm-report/report.qmd @@ -50,7 +50,7 @@ run_shas <- runs |> jsonlite::fromJSON() run_ids <- mruns(run_shas) |> - filter(commit.branch == "facebookincubator:main", substr(id, 1, 2) == "BM") |> + filter(substr(id, 1, 2) == "BM") |> pull(id) # Speed up local dev by saving 'results' as conbench requests can't be memoised diff --git a/scripts/docker/fedora.dockerfile b/scripts/docker/fedora.dockerfile index d098c03f721d..28478c5857bd 100644 --- a/scripts/docker/fedora.dockerfile +++ b/scripts/docker/fedora.dockerfile @@ -64,6 +64,7 @@ ENV UV_TOOL_BIN_DIR=/usr/local/bin \ RUN /bin/bash -c 'source /setup-fedora.sh && \ install_build_prerequisites && \ install_velox_deps_from_dnf && \ + dnf_install jq gh &&\ dnf clean all' RUN ln -s $(which python3) /usr/bin/python diff --git a/scripts/setup-centos9.sh b/scripts/setup-centos9.sh index b1dd08b0d880..a54226fd87d4 100755 --- a/scripts/setup-centos9.sh +++ b/scripts/setup-centos9.sh @@ -60,7 +60,7 @@ function install_build_prerequisites { ninja-build python3-pip python3-devel wget which install_uv - uv_install cmake + uv_install cmake@3.31.1 if [[ ${USE_CLANG} != "false" ]]; then install_clang15 diff --git 
a/scripts/setup-common.sh b/scripts/setup-common.sh index 594616649425..7bc5d8866e78 100755 --- a/scripts/setup-common.sh +++ b/scripts/setup-common.sh @@ -125,6 +125,8 @@ function install_boost { } function install_protobuf { + install_abseil + wget_and_untar https://github.com/protocolbuffers/protobuf/releases/download/v"${PROTOBUF_VERSION}"/protobuf-all-"${PROTOBUF_VERSION}".tar.gz protobuf cmake_install_dir protobuf -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_ABSL_PROVIDER=package } @@ -139,9 +141,29 @@ function install_ranges_v3 { cmake_install_dir ranges_v3 -DRANGES_ENABLE_WERROR=OFF -DRANGE_V3_TESTS=OFF -DRANGE_V3_EXAMPLES=OFF } +function install_abseil { + wget_and_untar https://github.com/abseil/abseil-cpp/archive/refs/tags/"${ABSEIL_VERSION}".tar.gz abseil-cpp + local OS + OS=$(uname) + if [[ $OS == "Darwin" ]]; then + ABSOLUTE_SCRIPTDIR=$(realpath "$SCRIPT_DIR") + ( + cd "${DEPENDENCY_DIR}/abseil-cpp" || exit 1 + git apply $ABSOLUTE_SCRIPTDIR/../CMake/resolve_dependency_modules/absl/absl-macos.patch + ) + fi + cmake_install_dir abseil-cpp \ + -DABSL_BUILD_TESTING=OFF \ + -DCMAKE_CXX_STANDARD=17 \ + -DABSL_PROPAGATE_CXX_STD=ON \ + -DABSL_ENABLE_INSTALL=ON +} + function install_re2 { + install_abseil + wget_and_untar https://github.com/google/re2/archive/refs/tags/"${RE2_VERSION}".tar.gz re2 - cmake_install_dir re2 -DRE2_BUILD_TESTING=OFF + cmake_install_dir re2 -DRE2_BUILD_TESTING=OFF -Dabsl_DIR="${INSTALL_PREFIX}/lib/cmake/absl" } function install_glog { @@ -303,12 +325,7 @@ function install_gcs-sdk-cpp { # https://github.com/googleapis/google-cloud-cpp/blob/main/doc/packaging.md#required-libraries # abseil-cpp - github_checkout abseil/abseil-cpp "${ABSEIL_VERSION}" --depth 1 - cmake_install \ - -DABSL_BUILD_TESTING=OFF \ - -DCMAKE_CXX_STANDARD=17 \ - -DABSL_PROPAGATE_CXX_STD=ON \ - -DABSL_ENABLE_INSTALL=ON + install_abseil # protobuf github_checkout protocolbuffers/protobuf v"${PROTOBUF_VERSION}" --depth 1 @@ -353,7 +370,7 @@ function 
install_azure-storage-sdk-cpp { export AZURE_SDK_DISABLE_AUTO_VCPKG=ON vcpkg_commit_id=7a6f366cefd27210f6a8309aed10c31104436509 github_checkout azure/azure-sdk-for-cpp azure-storage-files-datalake_"${AZURE_SDK_VERSION}" - sed -i "s/set(VCPKG_COMMIT_STRING .*)/set(VCPKG_COMMIT_STRING $vcpkg_commit_id)/" cmake-modules/AzureVcpkg.cmake + sed -i='' "s/set(VCPKG_COMMIT_STRING .*)/set(VCPKG_COMMIT_STRING $vcpkg_commit_id)/" cmake-modules/AzureVcpkg.cmake azure_core_dir="sdk/core/azure-core" if ! grep -q "baseline" $azure_core_dir/vcpkg.json; then @@ -362,8 +379,8 @@ function install_azure-storage-sdk-cpp { if [[ $openssl_version == 1.1.1* ]]; then openssl_version="1.1.1n" fi - sed -i "s/\"version-string\"/\"builtin-baseline\": \"$vcpkg_commit_id\",\"version-string\"/" $azure_core_dir/vcpkg.json - sed -i "s/\"version-string\"/\"overrides\": [{ \"name\": \"openssl\", \"version-string\": \"$openssl_version\" }],\"version-string\"/" $azure_core_dir/vcpkg.json + sed -i='' "s/\"version-string\"/\"builtin-baseline\": \"$vcpkg_commit_id\",\"version-string\"/" $azure_core_dir/vcpkg.json + sed -i='' "s/\"version-string\"/\"overrides\": [{ \"name\": \"openssl\", \"version-string\": \"$openssl_version\" }],\"version-string\"/" $azure_core_dir/vcpkg.json fi ( cd $azure_core_dir || exit @@ -393,7 +410,7 @@ function install_azure-storage-sdk-cpp { function install_hdfs_deps { # Dependencies for Hadoop testing - wget_and_untar https://archive.apache.org/dist/hadoop/common/hadoop-"${HADOOP_VERSION}"/hadoop-"${HADOOP_VERSION}".tar.gz hadoop + wget_and_untar https://dlcdn.apache.org/hadoop/common/hadoop-"${HADOOP_VERSION}"/hadoop-"${HADOOP_VERSION}".tar.gz hadoop cp -a "${DEPENDENCY_DIR}"/hadoop "$INSTALL_PREFIX" wget "${WGET_OPTS[@]}" -P "$INSTALL_PREFIX"/hadoop/share/hadoop/common/lib/ https://repo1.maven.org/maven2/junit/junit/4.11/junit-4.11.jar } diff --git a/scripts/setup-fedora.sh b/scripts/setup-fedora.sh index 619be70ec892..89df33d3d495 100755 --- a/scripts/setup-fedora.sh +++ 
b/scripts/setup-fedora.sh @@ -55,6 +55,35 @@ function install_build_prerequisites { fi } +function install_velox_deps_from_dnf { + dnf_install \ + bison boost-devel c-ares-devel curl-devel double-conversion-devel \ + elfutils-libelf-devel flex fmt-devel gflags-devel glog-devel gmock-devel \ + gtest-devel libdwarf-devel libevent-devel libicu-devel \ + libsodium-devel libzstd-devel lz4-devel openssl-devel-engine \ + re2-devel snappy-devel thrift-devel xxhash-devel zlib-devel grpc-devel grpc-plugins + + install_faiss_deps +} + +function install_velox_deps { + run_and_time install_velox_deps_from_dnf + run_and_time install_gcs-sdk-cpp #grpc, abseil, protobuf + run_and_time install_fast_float + run_and_time install_folly + run_and_time install_fizz + run_and_time install_wangle + run_and_time install_mvfst + run_and_time install_fbthrift + run_and_time install_duckdb + run_and_time install_stemmer + run_and_time install_arrow + run_and_time install_xsimd # to new in fedora repos + run_and_time install_simdjson # to new in fedora repos + run_and_time install_geos # to new in fedora repos + run_and_time install_faiss +} + (return 2>/dev/null) && return # If script was sourced, don't run commands. ( @@ -90,6 +119,8 @@ function install_build_prerequisites { set -u fi install_velox_deps + # BUILD_TESTING requires grpc + dnf_install grpc echo "All dependencies for Velox installed!" if [[ ${USE_CLANG} != "false" ]]; then echo "To use clang for the Velox build set the CC and CXX environment variables in your session." 
diff --git a/scripts/setup-versions.sh b/scripts/setup-versions.sh index 43428a7f4d8e..5842a3c63db8 100755 --- a/scripts/setup-versions.sh +++ b/scripts/setup-versions.sh @@ -26,7 +26,7 @@ ARROW_VERSION="15.0.0" DUCKDB_VERSION="v0.8.1" PROTOBUF_VERSION="21.8" XSIMD_VERSION="10.0.0" -SIMDJSON_VERSION="3.9.3" +SIMDJSON_VERSION="3.13.0" CPR_VERSION="1.10.5" DOUBLE_CONVERSION_VERSION="v3.1.5" RANGE_V3_VERSION="0.12.0" @@ -49,7 +49,7 @@ GRPC_VERSION="v1.48.1" CRC32_VERSION="1.1.2" NLOHMAN_JSON_VERSION="v3.11.3" GOOGLE_CLOUD_CPP_VERSION="v2.22.0" -HADOOP_VERSION="3.3.0" +HADOOP_VERSION="3.3.6" AZURE_SDK_VERSION="12.8.0" MINIO_VERSION="2022-05-26T05-48-41Z" MINIO_BINARY_NAME="minio-2022-05-26" diff --git a/velox/buffer/Buffer.cpp b/velox/buffer/Buffer.cpp index 806959295401..abb7a3f09d63 100644 --- a/velox/buffer/Buffer.cpp +++ b/velox/buffer/Buffer.cpp @@ -18,9 +18,24 @@ namespace facebook::velox { +std::string Buffer::typeString(Type type) { + switch (type) { + case Type::kPOD: + return "kPOD"; + case Type::kNonPOD: + return "kNonPOD"; + case Type::kPODView: + return "kPODView"; + case Type::kNonPODView: + return "kNonPODView"; + default: + return fmt::format("Unknown({})", static_cast(type)); + } +} + namespace { struct BufferReleaser { - explicit BufferReleaser(const BufferPtr& parent) : parent_(parent) {} + explicit BufferReleaser(BufferPtr parent) : parent_{std::move(parent)} {} void addRef() const {} void release() const {} diff --git a/velox/buffer/Buffer.h b/velox/buffer/Buffer.h index 48cd0a629094..58c1f10f2db7 100644 --- a/velox/buffer/Buffer.h +++ b/velox/buffer/Buffer.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include "velox/common/base/BitUtil.h" #include "velox/common/base/CheckedArithmetic.h" #include "velox/common/base/Exceptions.h" @@ -26,8 +28,7 @@ #include "velox/common/base/SimdUtil.h" #include "velox/common/memory/Memory.h" -namespace facebook { -namespace velox { +namespace facebook::velox { class Buffer; class AlignedBuffer; @@ 
-58,21 +59,40 @@ class Buffer { // type. Thus the conditions are: trivial destructor (no resources to release) // and trivially copyable (so memcpy works) template - static inline constexpr bool is_pod_like_v = + static constexpr bool is_pod_like_v = std::is_trivially_destructible_v && std::is_trivially_copyable_v; - virtual ~Buffer() {} + virtual ~Buffer() = default; - void addRef() { - referenceCount_.fetch_add(1); + static constexpr uint8_t kPODBit = 0; + static constexpr uint8_t kPODMask = 1 << kPODBit; + static constexpr uint8_t kViewBit = 1; + static constexpr uint8_t kViewMask = 1 << kViewBit; + static_assert(kPODBit != kViewBit); + + enum class Type : uint8_t { + kNonPOD = 0 << kPODBit | 0 << kViewBit, + kPOD = 1 << kPODBit | 0 << kViewBit, + kNonPODView = 0 << kPODBit | 1 << kViewBit, + kPODView = 1 << kPODBit | 1 << kViewBit, + }; + + static std::string typeString(Type type); + + Type type() const { + return type_; + } + + void addRef() noexcept { + referenceCount_.fetch_add(1, std::memory_order_acq_rel); } - int refCount() const { - return referenceCount_; + int refCount() const noexcept { + return referenceCount_.load(std::memory_order_acquire); } void release() { - if (referenceCount_.fetch_sub(1) == 1) { + if (referenceCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { releaseResources(); if (pool_) { freeToPool(); @@ -86,13 +106,13 @@ class Buffer { const T* as() const { // We can't check actual types, but we can sanity-check POD/non-POD // conversion. `void` is special as it's used in type-erased contexts - VELOX_DCHECK((std::is_same_v) || podType_ == is_pod_like_v); + VELOX_DCHECK(std::is_void_v || isPOD() == is_pod_like_v); return reinterpret_cast(data_); } template Range asRange() { - return Range(as(), 0, size() / sizeof(T)); + return {as(), 0, static_cast(size() / sizeof(T))}; } template @@ -102,16 +122,16 @@ class Buffer { VELOX_CHECK(!isView()); // We can't check actual types, but we can sanity-check POD/non-POD // conversion. 
`void` is special as it's used in type-erased contexts - VELOX_DCHECK((std::is_same_v) || podType_ == is_pod_like_v); + VELOX_DCHECK(std::is_void_v || isPOD() == is_pod_like_v); return reinterpret_cast(data_); } template MutableRange asMutableRange() { - return MutableRange(asMutable(), 0, size() / sizeof(T)); + return {asMutable(), 0, static_cast(size() / sizeof(T))}; } - size_t size() const { + size_t size() const noexcept { return size_; } @@ -126,24 +146,28 @@ class Buffer { checkEndGuard(); } - uint64_t capacity() const { + uint64_t capacity() const noexcept { return capacity_; } - bool unique() const { - return referenceCount_ == 1; + bool unique() const noexcept { + return refCount() == 1; } - velox::memory::MemoryPool* pool() const { + velox::memory::MemoryPool* pool() const noexcept { return pool_; } - bool isMutable() const { + bool isMutable() const noexcept { return !isView() && unique(); } - virtual bool isView() const { - return false; + bool isView() const { + return (static_cast(type_) & kViewMask) != 0; + } + + bool isPOD() const { + return (static_cast(type_) & kPODMask) != 0; } friend std::ostream& operator<<(std::ostream& os, const Buffer& buffer) { @@ -199,6 +223,10 @@ class Buffer { sizeof(T), is_pod_like_v, buffer, offset, length); } + virtual bool transferTo(velox::memory::MemoryPool* /*pool*/) { + VELOX_NYI("{} unsupported", __FUNCTION__); + } + protected: // Writes a magic word at 'capacity_'. No-op for a BufferView. 
The actual // logic is inside a separate virtual function, allowing override by derived @@ -241,7 +269,7 @@ class Buffer { virtual void copyFrom(const Buffer* other, size_t bytes) { VELOX_CHECK(!isView()); VELOX_CHECK_GE(capacity_, bytes); - VELOX_CHECK(podType_); + VELOX_CHECK_EQ(type_, Type::kPOD); memcpy(data_, other->data_, bytes); } @@ -252,27 +280,24 @@ class Buffer { } Buffer( - velox::memory::MemoryPool* pool, + Type type, uint8_t* data, size_t capacity, - bool podType) - : pool_(pool), - data_(data), - capacity_(capacity), - referenceCount_(0), - podType_(podType) {} + velox::memory::MemoryPool* pool) + : pool_{pool}, data_{data}, capacity_{capacity}, type_{type} {} velox::memory::MemoryPool* const pool_; uint8_t* const data_; - uint64_t size_ = 0; - uint64_t capacity_ = 0; - std::atomic referenceCount_; - bool podType_ = true; - // Pad to 64 bytes. If using as int32_t[], guarantee that value at index -1 == - // -1. - uint64_t padding_[2] = {static_cast(-1), static_cast(-1)}; - // Needs to use setCapacity() from static method reallocate(). - friend class AlignedBuffer; + + uint64_t size_{0}; + uint64_t capacity_; + std::atomic_int32_t referenceCount_{0}; + + const Type type_; + + // Pad to 64 bytes. + // If using as int32_t[], guarantee that value at index -1 == -1. + uint64_t padding_[2]{static_cast(-1), static_cast(-1)}; private: static BufferPtr sliceBufferZeroCopy( @@ -281,6 +306,9 @@ class Buffer { const BufferPtr& buffer, size_t offset, size_t length); + + // Needs to use setCapacity() from static method reallocate(). 
+ friend class AlignedBuffer; }; static_assert( @@ -289,12 +317,12 @@ static_assert( template <> inline Range Buffer::asRange() { - return Range(as(), 0, size() * 8); + return {as(), 0, static_cast(size() * 8)}; } template <> inline MutableRange Buffer::asMutableRange() { - return MutableRange(asMutable(), 0, size() * 8); + return {asMutable(), 0, static_cast(size() * 8)}; } template <> @@ -304,11 +332,11 @@ BufferPtr Buffer::slice( size_t length, memory::MemoryPool* pool); -static inline void intrusive_ptr_add_ref(Buffer* buffer) { +FOLLY_ALWAYS_INLINE void intrusive_ptr_add_ref(Buffer* buffer) noexcept { buffer->addRef(); } -static inline void intrusive_ptr_release(Buffer* buffer) { +FOLLY_ALWAYS_INLINE void intrusive_ptr_release(Buffer* buffer) noexcept { buffer->release(); } @@ -325,7 +353,7 @@ class AlignedBuffer : public Buffer { static constexpr int32_t kSizeofAlignedBuffer = 64; static constexpr int32_t kPaddedSize = kSizeofAlignedBuffer + simd::kPadding; - ~AlignedBuffer() { + ~AlignedBuffer() override { // This may throw, which is expected to signal an error to the // user. This is better for distributed debugging than killing the // process. In concept this indicates the possibility of memory @@ -337,10 +365,8 @@ class AlignedBuffer : public Buffer { // It's almost like partial specialization, but we redirect all POD types to // the same non-templated class template - using ImplClass = typename std::conditional< - is_pod_like_v, - AlignedBuffer, - NonPODAlignedBuffer>::type; + using ImplClass = std:: + conditional_t, AlignedBuffer, NonPODAlignedBuffer>; /** * Allocates enough memory to store numElements of type T. 
May @@ -368,7 +394,7 @@ class AlignedBuffer : public Buffer { void* memory = pool->allocate(preferredSize); VELOX_CHECK_NOT_NULL(memory); - auto* buffer = new (memory) ImplClass(pool, preferredSize - kPaddedSize); + auto* buffer = new (memory) ImplClass{pool, preferredSize - kPaddedSize}; // set size explicitly instead of setSize because `fillNewMemory` already // called the constructors buffer->size_ = size; @@ -377,6 +403,18 @@ class AlignedBuffer : public Buffer { return result; } + /// A verbose version of the allocate() with the exact size. + /// May allocate slightly more memory than strictly necessary. Guarantees that + /// simd::kPadding bytes past capacity() are addressable and asserts that + /// these do not get overrun. + template + static BufferPtr allocateExact( + size_t numElements, + velox::memory::MemoryPool* pool, + const std::optional& initValue = std::nullopt) { + return allocate(numElements, pool, initValue, true); + } + // Changes the capacity of '*buffer'. The buffer may grow/shrink in // place or may change addresses. The content is copied up to the // old size() or the new size, whichever is smaller. 
If the buffer grows, the @@ -418,34 +456,32 @@ class AlignedBuffer : public Buffer { // called the constructors newBuffer->size_ = size; *buffer = std::move(newBuffer); - return; - } - if (!old->unique()) { + } else if (!old->unique()) { auto newBuffer = allocate(numElements, pool); newBuffer->copyFrom(old, std::min(size, old->size())); reinterpret_cast(newBuffer.get()) ->template fillNewMemory(old->size(), size, initValue); newBuffer->size_ = size; *buffer = std::move(newBuffer); - return; - } - auto oldCapacity = checkedPlus(old->capacity(), kPaddedSize); - auto preferredSize = - pool->preferredSize(checkedPlus(size, kPaddedSize)); + } else { + auto oldCapacity = checkedPlus(old->capacity(), kPaddedSize); + auto preferredSize = + pool->preferredSize(checkedPlus(size, kPaddedSize)); - void* newPtr = pool->reallocate(old, oldCapacity, preferredSize); + void* newPtr = pool->reallocate(old, oldCapacity, preferredSize); - // Make the old buffer no longer owned by '*buffer' because reallocate - // freed the old buffer. Reassigning the new buffer to - // '*buffer' would be a double free if we didn't do this. - buffer->detach(); + // Make the old buffer no longer owned by '*buffer' because reallocate + // freed the old buffer. Reassigning the new buffer to + // '*buffer' would be a double free if we didn't do this. 
+ buffer->detach(); - auto newBuffer = - new (newPtr) AlignedBuffer(pool, preferredSize - kPaddedSize); - newBuffer->setSize(size); - newBuffer->fillNewMemory(oldSize, size, initValue); + auto newBuffer = + new (newPtr) AlignedBuffer{pool, preferredSize - kPaddedSize}; + newBuffer->setSize(size); + newBuffer->fillNewMemory(oldSize, size, initValue); - *buffer = newBuffer; + *buffer = newBuffer; + } } // Appends bytes starting at 'items' for a length of 'sizeof(T) * @@ -480,7 +516,7 @@ class AlignedBuffer : public Buffer { } VELOX_CHECK( - bufferPtr->podType_, "Support for non POD types not implemented yet"); + bufferPtr->isPOD(), "Support for non POD types not implemented yet"); // The reason we use uint8_t is because mutableNulls()->size() will return // in byte count. We also don't bother initializing since copyFrom will be @@ -492,13 +528,49 @@ class AlignedBuffer : public Buffer { return newBuffer; } + template + static BufferPtr copy( + const BufferPtr& buffer, + velox::memory::MemoryPool* pool) { + if (buffer == nullptr) { + return nullptr; + } + + // The reason we use uint8_t is because mutableNulls()->size() will return + // in byte count. We also don't bother initializing since copyFrom will be + // overwriting anyway. 
+ BufferPtr newBuffer; + if constexpr (std::is_same_v) { + newBuffer = AlignedBuffer::allocate(buffer->size(), pool); + } else { + const auto numElements = checkedDivide(buffer->size(), sizeof(T)); + newBuffer = AlignedBuffer::allocate(numElements, pool); + } + + newBuffer->copyFrom(buffer.get(), newBuffer->size()); + + return newBuffer; + } + + bool transferTo(velox::memory::MemoryPool* pool) override { + if (pool_ == pool) { + return true; + } + if (pool_->transferTo( + pool, this, checkedPlus(kPaddedSize, capacity_))) { + setPool(pool); + return true; + } + return false; + } + protected: AlignedBuffer(velox::memory::MemoryPool* pool, size_t capacity) - : Buffer( - pool, + : Buffer{ + Type::kPOD, reinterpret_cast(this) + sizeof(*this), capacity, - true /*podType*/) { + pool} { static_assert(sizeof(*this) == kAlignment); static_assert(sizeof(*this) == kSizeofAlignedBuffer); setEndGuard(); @@ -532,7 +604,12 @@ class AlignedBuffer : public Buffer { } } - protected: + void setPool(velox::memory::MemoryPool* pool) { + velox::memory::MemoryPool** poolPtr = + const_cast(&pool_); + *poolPtr = pool; + } + void setEndGuardImpl() override { *reinterpret_cast(data_ + capacity_) = kEndGuard; } @@ -597,13 +674,30 @@ class NonPODAlignedBuffer : public Buffer { } } + bool transferTo(velox::memory::MemoryPool* pool) override { + if (pool_ == pool) { + return true; + } + + if (pool_->transferTo( + pool, + this, + checkedPlus(AlignedBuffer::kPaddedSize, capacity_))) { + velox::memory::MemoryPool** poolPtr = + const_cast(&pool_); + *poolPtr = pool; + return true; + } + return false; + } + protected: NonPODAlignedBuffer(velox::memory::MemoryPool* pool, size_t capacity) - : Buffer( - pool, + : Buffer{ + Type::kNonPOD, reinterpret_cast(this) + sizeof(*this), capacity, - false /*podType*/) { + pool} { static_assert(sizeof(*this) == AlignedBuffer::kAlignment); static_assert(sizeof(*this) == sizeof(AlignedBuffer)); } @@ -611,8 +705,8 @@ class NonPODAlignedBuffer : public Buffer { void 
releaseResources() override { VELOX_CHECK_EQ(size_ % sizeof(T), 0); size_t numValues = size_ / sizeof(T); - // we can't use asMutable because it checks isMutable and we wan't to - // destroy regardless + // we can't use asMutable because it checks isMutable and we wan't + // to destroy regardless T* ptr = reinterpret_cast(data_); for (int i = 0; i < numValues; ++i) { ptr[i].~T(); @@ -620,6 +714,8 @@ class NonPODAlignedBuffer : public Buffer { } void copyFrom(const Buffer* other, size_t bytes) override { + // TODO: change this to isMutable(). See + // https://github.com/facebookincubator/velox/issues/6562. VELOX_CHECK(!isView()); VELOX_CHECK_GE(size_, bytes); VELOX_DCHECK( @@ -676,47 +772,60 @@ class NonPODAlignedBuffer : public Buffer { template class BufferView : public Buffer { public: - static BufferPtr create( - const uint8_t* data, - size_t size, - Releaser releaser, - bool podType = true) { - BufferView* view = new BufferView(data, size, releaser, podType); - BufferPtr result(view); + template + static BufferPtr + create(const uint8_t* data, size_t size, R&& releaser, bool podType = true) { + auto* view = new BufferView{data, size, std::forward(releaser), podType}; + BufferPtr result{view}; return result; } // Helper method to create a buffer view referencing another existing Buffer. 
+ template static BufferPtr - create(BufferPtr innerBuffer, Releaser releaser, bool podType = true) { + create(const BufferPtr& innerBuffer, R&& releaser, bool podType = true) { return create( - innerBuffer->as(), innerBuffer->size(), releaser, podType); + innerBuffer->as(), + innerBuffer->size(), + std::forward(releaser), + podType); } ~BufferView() override { releaser_.release(); } - bool isView() const override { - return true; + bool transferTo(velox::memory::MemoryPool* pool) override { + if (pool_ == pool) { + return true; + } + return false; } private: - BufferView(const uint8_t* data, size_t size, Releaser releaser, bool podType) + template + BufferView(const uint8_t* data, size_t size, R&& releaser, bool podType) // A BufferView must be created over the data held by a cache // pin, which is typically const. The Buffer enforces const-ness // when returning the pointer. We cast away the const here to // avoid a separate code path for const and non-const Buffer // payloads. - : Buffer(nullptr, const_cast(data), size, podType), - releaser_(releaser) { + : Buffer{podType ? 
Type::kPODView : Type::kNonPODView, const_cast(data), size, nullptr}, + releaser_{std::forward(releaser)} { size_ = size; - capacity_ = size; releaser_.addRef(); } - Releaser const releaser_; + [[no_unique_address]] const Releaser releaser_; }; -} // namespace velox -} // namespace facebook +} // namespace facebook::velox + +// fmt formatter specialization for Buffer::Type +template <> +struct fmt::formatter : formatter { + auto format(facebook::velox::Buffer::Type s, format_context& ctx) const { + return formatter::format( + facebook::velox::Buffer::typeString(s), ctx); + } +}; diff --git a/velox/buffer/tests/BufferTest.cpp b/velox/buffer/tests/BufferTest.cpp index 9db4963221da..6bc6d1b51549 100644 --- a/velox/buffer/tests/BufferTest.cpp +++ b/velox/buffer/tests/BufferTest.cpp @@ -143,6 +143,20 @@ TEST_F(BufferTest, testAlignedBufferExact) { EXPECT_GE(buffer4->capacity(), oneMBMinusPad + 1); } +TEST_F(BufferTest, testAllocateExact) { + const int32_t oneMBMinusPad = 1024 * 1024 - AlignedBuffer::kPaddedSize; + + BufferPtr buffer1 = AlignedBuffer::allocateExact( + oneMBMinusPad + 1, pool_.get(), std::nullopt); + EXPECT_EQ(buffer1->size(), oneMBMinusPad + 1); + EXPECT_GE(buffer1->capacity(), oneMBMinusPad + 1); + + BufferPtr buffer2 = AlignedBuffer::allocateExact(3, pool_.get(), 'i'); + for (size_t i = 0; i < buffer2->size(); i++) { + EXPECT_EQ(buffer2->as()[i], 'i'); + } +} + TEST_F(BufferTest, testAsRange) { // Simple 2 element vector. 
std::vector testData({5, 255}); @@ -535,5 +549,34 @@ TEST_F(BufferTest, sliceBooleanBuffer) { Buffer::slice(bufferPtr, 5, 6, nullptr), "Pool must not be null."); } +TEST_F(BufferTest, testType) { + // Test AlignedBuffer type + auto alignedBuffer = AlignedBuffer::allocate(100, pool_.get()); + EXPECT_EQ(alignedBuffer->type(), Buffer::Type::kPOD); + EXPECT_TRUE(alignedBuffer->isPOD()); + EXPECT_FALSE(alignedBuffer->isView()); + + // Test NonPODAlignedBuffer type + auto nonPODBuffer = AlignedBuffer::allocate(10, pool_.get()); + EXPECT_EQ(nonPODBuffer->type(), Buffer::Type::kNonPOD); + EXPECT_FALSE(nonPODBuffer->isPOD()); + EXPECT_FALSE(nonPODBuffer->isView()); + + // Test BufferView type + MockCachePin pin; + const char* data = "test data"; + auto podBufferView = BufferView::create( + reinterpret_cast(data), 9, pin); + EXPECT_EQ(podBufferView->type(), Buffer::Type::kPODView); + EXPECT_TRUE(podBufferView->isPOD()); + EXPECT_TRUE(podBufferView->isView()); + + auto nonPodBufferView = BufferView::create( + reinterpret_cast(data), 9, pin, false); + EXPECT_EQ(nonPodBufferView->type(), Buffer::Type::kNonPODView); + EXPECT_FALSE(nonPodBufferView->isPOD()); + EXPECT_TRUE(nonPodBufferView->isView()); +} + } // namespace velox } // namespace facebook diff --git a/velox/common/CMakeLists.txt b/velox/common/CMakeLists.txt index 6661427654be..f6457da34760 100644 --- a/velox/common/CMakeLists.txt +++ b/velox/common/CMakeLists.txt @@ -18,6 +18,8 @@ add_subdirectory(config) add_subdirectory(dynamic_registry) add_subdirectory(encode) add_subdirectory(file) +add_subdirectory(future) +add_subdirectory(geospatial) add_subdirectory(hyperloglog) add_subdirectory(io) add_subdirectory(memory) @@ -26,3 +28,5 @@ add_subdirectory(serialization) add_subdirectory(time) add_subdirectory(testutil) add_subdirectory(fuzzer) + +velox_install_library_headers() diff --git a/velox/common/base/AdmissionController.cpp b/velox/common/base/AdmissionController.cpp index c7e1b71ea5ac..e1dae70a402e 100644 --- 
a/velox/common/base/AdmissionController.cpp +++ b/velox/common/base/AdmissionController.cpp @@ -32,12 +32,10 @@ void AdmissionController::accept(uint64_t resourceUnits) { { std::lock_guard l(mu_); if (unitsUsed_ + resourceUnits > config_.maxLimit) { - auto [unblockPromise, unblockFuture] = makeVeloxContinuePromiseContract(); Request req; req.unitsRequested = resourceUnits; - req.promise = std::move(unblockPromise); + future = req.promise.getSemiFuture(); queue_.push_back(std::move(req)); - future = std::move(unblockFuture); } else { updatedValue = unitsUsed_ += resourceUnits; } diff --git a/velox/common/base/CoalesceIo.h b/velox/common/base/CoalesceIo.h index f5f58ba1092b..7ad405cf2908 100644 --- a/velox/common/base/CoalesceIo.h +++ b/velox/common/base/CoalesceIo.h @@ -65,14 +65,13 @@ CoalesceIoStats coalesceIo( AddRanges addRanges, SkipRange skipRange, IoFunc ioFunc) { - std::vector buffers; int32_t startItem = 0; auto startOffset = offsetFunc(startItem); auto lastEndOffset = startOffset; std::vector ranges; CoalesceIoStats result; for (int32_t i = 0; i < items.size(); ++i) { - auto& item = items[i]; + const auto& item = items[i]; const auto itemOffset = offsetFunc(i); const auto itemSize = sizeFunc(i); result.payloadBytes += itemSize; diff --git a/velox/common/base/CountBits.h b/velox/common/base/CountBits.h index b267d2f636ef..f40fb95355e5 100644 --- a/velox/common/base/CountBits.h +++ b/velox/common/base/CountBits.h @@ -16,6 +16,8 @@ #pragma once +#include + namespace facebook::velox { // Copied from format.h of fmt. diff --git a/velox/common/base/Counters.cpp b/velox/common/base/Counters.cpp index d33536567892..c00803818f9c 100644 --- a/velox/common/base/Counters.cpp +++ b/velox/common/base/Counters.cpp @@ -107,8 +107,13 @@ void registerVeloxMetrics() { // was opened to load the cache. DEFINE_METRIC(kMetricCacheMaxAgeSecs, facebook::velox::StatType::AVG); - // Total number of cache entries. 
- DEFINE_METRIC(kMetricMemoryCacheNumEntries, facebook::velox::StatType::AVG); + // Total number of tiny cache entries. + DEFINE_METRIC( + kMetricMemoryCacheNumTinyEntries, facebook::velox::StatType::AVG); + + // Total number of large cache entries. + DEFINE_METRIC( + kMetricMemoryCacheNumLargeEntries, facebook::velox::StatType::AVG); // Total number of cache entries that do not cache anything. DEFINE_METRIC( diff --git a/velox/common/base/Counters.h b/velox/common/base/Counters.h index a936a8715467..916bc979c603 100644 --- a/velox/common/base/Counters.h +++ b/velox/common/base/Counters.h @@ -195,8 +195,11 @@ constexpr folly::StringPiece kMetricMmapAllocatorDelegatedAllocatedBytes{ constexpr folly::StringPiece kMetricCacheMaxAgeSecs{"velox.cache_max_age_secs"}; -constexpr folly::StringPiece kMetricMemoryCacheNumEntries{ - "velox.memory_cache_num_entries"}; +constexpr folly::StringPiece kMetricMemoryCacheNumTinyEntries{ + "velox.memory_cache_num_tiny_entries"}; + +constexpr folly::StringPiece kMetricMemoryCacheNumLargeEntries{ + "velox.memory_cache_num_large_entries"}; constexpr folly::StringPiece kMetricMemoryCacheNumEmptyEntries{ "velox.memory_cache_num_empty_entries"}; diff --git a/velox/common/base/PeriodicStatsReporter.cpp b/velox/common/base/PeriodicStatsReporter.cpp index 8afbabd59c3d..84f471b2eb51 100644 --- a/velox/common/base/PeriodicStatsReporter.cpp +++ b/velox/common/base/PeriodicStatsReporter.cpp @@ -139,7 +139,10 @@ void PeriodicStatsReporter::reportCacheStats() { const auto cacheStats = cache_->refreshStats(); // Memory cache snapshot stats. 
- RECORD_METRIC_VALUE(kMetricMemoryCacheNumEntries, cacheStats.numEntries); + RECORD_METRIC_VALUE( + kMetricMemoryCacheNumTinyEntries, cacheStats.numTinyEntries); + RECORD_METRIC_VALUE( + kMetricMemoryCacheNumLargeEntries, cacheStats.numLargeEntries); RECORD_METRIC_VALUE( kMetricMemoryCacheNumEmptyEntries, cacheStats.numEmptyEntries); RECORD_METRIC_VALUE(kMetricMemoryCacheNumSharedEntries, cacheStats.numShared); diff --git a/velox/common/base/Portability.h b/velox/common/base/Portability.h index 60049fcc54c7..85e890585e58 100644 --- a/velox/common/base/Portability.h +++ b/velox/common/base/Portability.h @@ -19,6 +19,7 @@ #include #include #include +#include #include inline size_t count_trailing_zeros(uint64_t x) { diff --git a/velox/common/base/SimdUtil.h b/velox/common/base/SimdUtil.h index 1aabc0f2952f..d853931776e3 100644 --- a/velox/common/base/SimdUtil.h +++ b/velox/common/base/SimdUtil.h @@ -430,6 +430,7 @@ uint32_t crc32U64(uint32_t checksum, uint64_t value, const A& arch = {}) { template xsimd::batch iota(const A& = {}); +#ifdef VELOX_ENABLE_LOAD_SIMD_VALUE_BUFFER // Returns a batch with all elements set to value. For batch we // use one bit to represent one element. template @@ -445,6 +446,7 @@ xsimd::batch setAll(T value, const A& = {}) { return xsimd::broadcast(value); } } +#endif // Stores 'data' into 'destination' for the lanes in 'mask'. 'mask' is expected // to specify contiguous lower lanes of 'batch'. 
For non-SIMD cases, 'mask' is diff --git a/velox/common/base/SpillConfig.cpp b/velox/common/base/SpillConfig.cpp index dd428a41ec7b..6308b37d58db 100644 --- a/velox/common/base/SpillConfig.cpp +++ b/velox/common/base/SpillConfig.cpp @@ -35,7 +35,8 @@ SpillConfig::SpillConfig( uint64_t _writerFlushThresholdSize, const std::string& _compressionKind, std::optional _prefixSortConfig, - const std::string& _fileCreateConfig) + const std::string& _fileCreateConfig, + uint32_t _windowMinReadBatchRows) : getSpillDirPathCb(std::move(_getSpillDirPathCb)), updateAndCheckSpillLimitCb(std::move(_updateAndCheckSpillLimitCb)), fileNamePrefix(std::move(_fileNamePrefix)), @@ -54,7 +55,8 @@ SpillConfig::SpillConfig( writerFlushThresholdSize(_writerFlushThresholdSize), compressionKind(common::stringToCompressionKind(_compressionKind)), prefixSortConfig(_prefixSortConfig), - fileCreateConfig(_fileCreateConfig) { + fileCreateConfig(_fileCreateConfig), + windowMinReadBatchRows(_windowMinReadBatchRows) { VELOX_USER_CHECK_GE( spillableReservationGrowthPct, minSpillableReservationPct, diff --git a/velox/common/base/SpillConfig.h b/velox/common/base/SpillConfig.h index 7f30bc6e614f..1fe13436ad5b 100644 --- a/velox/common/base/SpillConfig.h +++ b/velox/common/base/SpillConfig.h @@ -52,6 +52,13 @@ using GetSpillDirectoryPathCB = std::function; /// bytes exceed the set limit. using UpdateAndCheckSpillLimitCB = std::function; +/// Specifies the options for spill to disk. +struct SpillDiskOptions { + std::string spillDirPath; + bool spillDirCreated{true}; + std::function spillDirCreateCb{nullptr}; +}; + /// Specifies the config for spilling. 
struct SpillConfig { SpillConfig() = default; @@ -72,7 +79,8 @@ struct SpillConfig { uint64_t _writerFlushThresholdSize, const std::string& _compressionKind, std::optional _prefixSortConfig = std::nullopt, - const std::string& _fileCreateConfig = {}); + const std::string& _fileCreateConfig = {}, + uint32_t _windowMinReadBatchRows = 1'000); /// Returns the spilling level with given 'startBitOffset' and /// 'numPartitionBits'. @@ -157,5 +165,8 @@ struct SpillConfig { /// Custom options passed to velox::FileSystem to create spill WriteFile. std::string fileCreateConfig; + + /// The minimum number of rows to read when processing spilled window data. + uint32_t windowMinReadBatchRows; }; } // namespace facebook::velox::common diff --git a/velox/common/base/StatsReporter.h b/velox/common/base/StatsReporter.h index 94d0bd32c4f8..9df453a75db1 100644 --- a/velox/common/base/StatsReporter.h +++ b/velox/common/base/StatsReporter.h @@ -63,7 +63,7 @@ enum class StatType { HISTOGRAM, }; -inline std::string statTypeString(StatType stat) { +inline std::string_view statTypeString(StatType stat) { switch (stat) { case StatType::AVG: return "Avg"; @@ -76,7 +76,7 @@ inline std::string statTypeString(StatType stat) { case StatType::HISTOGRAM: return "Histogram"; default: - return fmt::format("UNKNOWN: {}", static_cast(stat)); + return "Unknown"; } } @@ -84,17 +84,17 @@ inline std::string statTypeString(StatType stat) { /// different implementations. class BaseStatsReporter { public: - virtual ~BaseStatsReporter() {} + virtual ~BaseStatsReporter() = default; /// Register a stat of the given stat type. /// @param key The key to identify the stat. /// @param statType How the stat is aggregated. virtual void registerMetricExportType(const char* key, StatType statType) - const = 0; + const {} virtual void registerMetricExportType( folly::StringPiece key, - StatType statType) const = 0; + StatType statType) const {} /// Register a histogram with a list of percentiles defined. 
/// @param key The key to identify the histogram. @@ -107,14 +107,14 @@ class BaseStatsReporter { int64_t bucketWidth, int64_t min, int64_t max, - const std::vector& pcts) const = 0; + const std::vector& pcts) const {} virtual void registerHistogramMetricExportType( folly::StringPiece key, int64_t bucketWidth, int64_t min, int64_t max, - const std::vector& pcts) const = 0; + const std::vector& pcts) const {} /// Register a quantile metric for quantile stats with export types, /// quantiles, and sliding window periods. @@ -127,13 +127,13 @@ class BaseStatsReporter { const char* key, const std::vector& statTypes, const std::vector& pcts, - const std::vector& slidingWindowsSeconds = {60}) const = 0; + const std::vector& slidingWindowsSeconds = {60}) const {} virtual void registerQuantileMetricExportType( folly::StringPiece key, const std::vector& statTypes, const std::vector& pcts, - const std::vector& slidingWindowsSeconds = {60}) const = 0; + const std::vector& slidingWindowsSeconds = {60}) const {} /// Register a dynamic quantile metric with a template key pattern that /// supports runtime substitution. @@ -145,60 +145,60 @@ class BaseStatsReporter { const char* keyPattern, const std::vector& statTypes, const std::vector& pcts, - const std::vector& slidingWindowsSeconds = {60}) const = 0; + const std::vector& slidingWindowsSeconds = {60}) const {} virtual void registerDynamicQuantileMetricExportType( folly::StringPiece keyPattern, const std::vector& statTypes, const std::vector& pcts, - const std::vector& slidingWindowsSeconds = {60}) const = 0; + const std::vector& slidingWindowsSeconds = {60}) const {} /// Add the given value to the stat. 
- virtual void addMetricValue(const std::string& key, size_t value = 1) - const = 0; + virtual void addMetricValue(const std::string& key, size_t value = 1) const {} - virtual void addMetricValue(const char* key, size_t value = 1) const = 0; + virtual void addMetricValue(const char* key, size_t value = 1) const {} - virtual void addMetricValue(folly::StringPiece key, size_t value = 1) - const = 0; + virtual void addMetricValue(folly::StringPiece key, size_t value = 1) const {} /// Add the given value to the histogram. virtual void addHistogramMetricValue(const std::string& key, size_t value) - const = 0; + const {} - virtual void addHistogramMetricValue(const char* key, size_t value) const = 0; + virtual void addHistogramMetricValue(const char* key, size_t value) const {} virtual void addHistogramMetricValue(folly::StringPiece key, size_t value) - const = 0; + const {} /// Add the given value to a quantile metric. virtual void addQuantileMetricValue(const std::string& key, size_t value = 1) - const = 0; + const {} - virtual void addQuantileMetricValue(const char* key, size_t value = 1) - const = 0; + virtual void addQuantileMetricValue(const char* key, size_t value = 1) const { + } virtual void addQuantileMetricValue(folly::StringPiece key, size_t value = 1) - const = 0; + const {} /// Add the given value to a quantile metric. virtual void addDynamicQuantileMetricValue( const std::string& key, folly::Range subkeys, - size_t value = 1) const = 0; + size_t value = 1) const {} virtual void addDynamicQuantileMetricValue( const char* key, folly::Range subkeys, - size_t value = 1) const = 0; + size_t value = 1) const {} virtual void addDynamicQuantileMetricValue( folly::StringPiece key, folly::Range subkeys, - size_t value = 1) const = 0; + size_t value = 1) const {} /// Return the aggregated metrics in a serialized string format. 
- virtual std::string fetchMetrics() = 0; + virtual std::string fetchMetrics() { + return ""; + } static bool registered; }; diff --git a/velox/common/base/Status.h b/velox/common/base/Status.h index 72ec5447618d..1cf42673bc82 100644 --- a/velox/common/base/Status.h +++ b/velox/common/base/Status.h @@ -530,6 +530,9 @@ void Status::moveFrom(Status& s) { #define _VELOX_RETURN_IMPL(expr, exprStr, error, ...) \ do { \ if (FOLLY_UNLIKELY(expr)) { \ + if (::facebook::velox::threadSkipErrorDetails()) { \ + return error(); \ + } \ auto message = ::facebook::velox::errorMessage(__VA_ARGS__); \ return error( \ ::facebook::velox::internal::generateError(message, exprStr)); \ diff --git a/velox/exec/TreeOfLosers.h b/velox/common/base/TreeOfLosers.h similarity index 100% rename from velox/exec/TreeOfLosers.h rename to velox/common/base/TreeOfLosers.h diff --git a/velox/common/base/tests/SimdUtilTest.cpp b/velox/common/base/tests/SimdUtilTest.cpp index 447cc55a6e2f..4b932d9a1b3e 100644 --- a/velox/common/base/tests/SimdUtilTest.cpp +++ b/velox/common/base/tests/SimdUtilTest.cpp @@ -126,6 +126,7 @@ class SimdUtilTest : public testing::Test { folly::Random::DefaultGenerator rng_; }; +#ifdef VELOX_ENABLE_LOAD_SIMD_VALUE_BUFFER TEST_F(SimdUtilTest, setAll) { auto bits = simd::setAll(true); auto words = reinterpret_cast(&bits); @@ -133,6 +134,7 @@ TEST_F(SimdUtilTest, setAll) { EXPECT_EQ(words[i], -1ll); } } +#endif TEST_F(SimdUtilTest, bitIndices) { testIndices(1); diff --git a/velox/common/base/tests/StatsReporterTest.cpp b/velox/common/base/tests/StatsReporterTest.cpp index 773c52db6c81..d21a870b46b7 100644 --- a/velox/common/base/tests/StatsReporterTest.cpp +++ b/velox/common/base/tests/StatsReporterTest.cpp @@ -28,7 +28,7 @@ #include "velox/common/caching/SsdCache.h" #include "velox/common/memory/MmapAllocator.h" -namespace facebook::velox { +namespace facebook::velox::test { struct QuantileConfig { std::vector statTypes; @@ -412,7 +412,8 @@ TEST_F(PeriodicStatsReporterTest, 
basic) { ASSERT_EQ(counterMap.count(kMetricArbitratorFreeCapacityBytes.str()), 1); ASSERT_EQ( counterMap.count(kMetricArbitratorFreeReservedCapacityBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEntries.str()), 1); + ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumTinyEntries.str()), 1); + ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumLargeEntries.str()), 1); ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEmptyEntries.str()), 1); ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumSharedEntries.str()), 1); ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumExclusiveEntries.str()), 1); @@ -471,7 +472,7 @@ TEST_F(PeriodicStatsReporterTest, basic) { ASSERT_EQ(counterMap.count(kMetricSsdCacheAgedOutRegions.str()), 0); ASSERT_EQ(counterMap.count(kMetricSsdCacheRecoveredEntries.str()), 0); ASSERT_EQ(counterMap.count(kMetricSsdCacheReadWithoutChecksum.str()), 0); - ASSERT_EQ(counterMap.size(), 23); + ASSERT_EQ(counterMap.size(), 24); } // Update stats @@ -554,7 +555,7 @@ TEST_F(PeriodicStatsReporterTest, basic) { ASSERT_EQ(counterMap.count(kMetricSsdCacheAgedOutRegions.str()), 1); ASSERT_EQ(counterMap.count(kMetricSsdCacheRecoveredEntries.str()), 1); ASSERT_EQ(counterMap.count(kMetricSsdCacheReadWithoutChecksum.str()), 1); - ASSERT_EQ(counterMap.size(), 55); + ASSERT_EQ(counterMap.size(), 56); } } @@ -903,7 +904,7 @@ folly::Singleton reporter([]() { return new TestReporter(); }); -} // namespace facebook::velox +} // namespace facebook::velox::test int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/velox/common/base/tests/StatsReporterUtils.h b/velox/common/base/tests/StatsReporterUtils.h index b65bee056e23..f58111c3e29a 100644 --- a/velox/common/base/tests/StatsReporterUtils.h +++ b/velox/common/base/tests/StatsReporterUtils.h @@ -26,13 +26,11 @@ #include #include "velox/common/base/StatsReporter.h" -namespace facebook::velox { +namespace facebook::velox::test { -/** - * A test implementation of BaseStatsReporter for use 
in unit tests. - * This class provides a mock implementation that captures all metric - * registrations and values for verification in tests. - */ +/// A test implementation of BaseStatsReporter for use in unit tests. +/// This class provides a mock implementation that captures all metric +/// registrations and values for verification in tests. class TestReporter : public BaseStatsReporter { public: mutable std::mutex m; @@ -239,10 +237,8 @@ class TestReporter : public BaseStatsReporter { return ss.str(); } - /** - * Get the current counter value for a specific key. - * Returns 0 if the key doesn't exist. - */ + // Get the current counter value for a specific key. + // Returns 0 if the key doesn't exist. size_t getCounterValue(const std::string& key) const { std::lock_guard l(m); auto it = counterMap.find(key); @@ -250,4 +246,4 @@ class TestReporter : public BaseStatsReporter { } }; -} // namespace facebook::velox +} // namespace facebook::velox::test diff --git a/velox/common/base/tests/StatusTest.cpp b/velox/common/base/tests/StatusTest.cpp index 282563d142f8..1501c18dd303 100644 --- a/velox/common/base/tests/StatusTest.cpp +++ b/velox/common/base/tests/StatusTest.cpp @@ -211,6 +211,22 @@ TEST(StatusTest, statusMacros) { "Reason: User error occurred.\nExpression: status != nullptr\n")); } +TEST(StatusTest, statusMacrosSkipDetails) { + ScopedThreadSkipErrorDetails skipErrorDetails(true); + ASSERT_EQ(returnMacroCheck(), Status::UserError()); + ASSERT_EQ(returnMacroEmptyMessage(), Status::UserError()); + ASSERT_EQ(returnMacroFormat(), Status::UserError()); + ASSERT_EQ(returnMacroGT(), Status::UserError()); + ASSERT_EQ(returnMacroGE(), Status::UserError()); + ASSERT_EQ(returnMacroLT(), Status::UserError()); + ASSERT_EQ(returnMacroLE(), Status::UserError()); + ASSERT_EQ(returnMacroEQ(), Status::UserError()); + ASSERT_EQ(returnMacroNE(), Status::UserError()); + ASSERT_EQ(returnMacroNULL(), Status::UserError()); + Status status = Status::OK(); + 
ASSERT_EQ(returnNotNull(&status), Status::UserError()); +} + Expected modulo(int a, int b) { if (b == 0) { return folly::makeUnexpected(Status::UserError("division by zero")); diff --git a/velox/common/caching/AsyncDataCache.cpp b/velox/common/caching/AsyncDataCache.cpp index 717107416d51..94fe2641296b 100644 --- a/velox/common/caching/AsyncDataCache.cpp +++ b/velox/common/caching/AsyncDataCache.cpp @@ -151,7 +151,7 @@ std::string AsyncDataCacheEntry::toString() const { numPins_); } -std::unique_ptr CacheShard::getFreeEntry() { +std::unique_ptr CacheShard::getFreeEntryLocked() { std::unique_ptr newEntry; if (freeEntries_.empty()) { newEntry = std::make_unique(this); @@ -176,7 +176,7 @@ CachePin CacheShard::findOrCreate( if (foundEntry->isExclusive()) { ++numWaitExclusive_; if (wait != nullptr) { - *wait = foundEntry->getFuture(); + *wait = foundEntry->getFutureLocked(); } return CachePin(); } @@ -212,7 +212,7 @@ CachePin CacheShard::findOrCreate( entryMap_.erase(it); } - auto newEntry = getFreeEntry(); + auto newEntry = getFreeEntryLocked(); // Initialize the members that must be set inside 'mutex_'. newEntry->numPins_ = AsyncDataCacheEntry::kExclusive; newEntry->promise_ = nullptr; @@ -336,7 +336,7 @@ std::unique_ptr> CacheShard::removeEntry( removeEntryLocked(entry); // After the entry is removed from the hash table, a promise can no longer // be made. It is safe to move the promise and realize it. 
- return entry->movePromise(); + return entry->movePromiseLocked(); } void CacheShard::removeEntryLocked(AsyncDataCacheEntry* entry) { @@ -409,7 +409,7 @@ uint64_t CacheShard::evict( eventCounter_ > entries_.size() / 4 || numChecked > entries_.size() / 8) { now = accessTime(); - calibrateThreshold(); + calibrateThresholdLocked(); numChecked = 0; eventCounter_ = 0; } @@ -490,7 +490,7 @@ void CacheShard::freeAllocations(std::vector& allocations) { allocations.clear(); } -void CacheShard::calibrateThreshold() { +void CacheShard::calibrateThresholdLocked() { auto numSamples = std::min(10, entries_.size()); auto now = accessTime(); auto entryIndex = (clockHand_ % entries_.size()); @@ -537,11 +537,15 @@ void CacheShard::updateStats(CacheStats& stats) { } ++stats.numEntries; - stats.tinySize += entry->tinyData_.size(); - stats.tinyPadding += entry->tinyData_.capacity() - entry->tinyData_.size(); if (entry->tinyData_.empty()) { stats.largeSize += entry->size_; stats.largePadding += entry->data_.byteSize() - entry->size_; + ++stats.numLargeEntries; + } else { + stats.tinySize += entry->tinyData_.size(); + stats.tinyPadding += + entry->tinyData_.capacity() - entry->tinyData_.size(); + ++stats.numTinyEntries; } } stats.numHit += numHit_; @@ -658,7 +662,7 @@ CacheStats CacheStats::operator-(const CacheStats& other) const { AsyncDataCache::AsyncDataCache( memory::MemoryAllocator* allocator, std::unique_ptr ssdCache) - : AsyncDataCache({}, allocator, std::move(ssdCache)){}; + : AsyncDataCache({}, allocator, std::move(ssdCache)) {} AsyncDataCache::AsyncDataCache( const Options& options, @@ -875,7 +879,7 @@ bool AsyncDataCache::canTryAllocate( return true; } return numPages - acquired.numPages() <= - (memory::AllocationTraits::numPages(allocator_->capacity())) - + memory::AllocationTraits::numPages(allocator_->capacity()) - allocator_->numAllocated(); } diff --git a/velox/common/caching/AsyncDataCache.h b/velox/common/caching/AsyncDataCache.h index 605e95cd5689..699e5b131d7c 
100644 --- a/velox/common/caching/AsyncDataCache.h +++ b/velox/common/caching/AsyncDataCache.h @@ -263,22 +263,21 @@ class AsyncDataCacheEntry { /// Sets access stats so that this is immediately evictable. void makeEvictable(); - /// Moves the promise out of 'this'. Used in order to handle the - /// promise within the lock of the cache shard, so not within private - /// methods of 'this'. - std::unique_ptr> movePromise() { - return std::move(promise_); - } - std::string toString() const; private: void release(); void addReference(); + // Moves the promise out of 'this'. Must be called inside the mutex of + // 'shard_'. + std::unique_ptr> movePromiseLocked() { + return std::move(promise_); + } + // Returns a future that will be realized when a caller can retry getting // 'this'. Must be called inside the mutex of 'shard_'. - folly::SemiFuture getFuture() { + folly::SemiFuture getFutureLocked() { if (promise_ == nullptr) { promise_ = std::make_unique>(); } @@ -308,7 +307,7 @@ class AsyncDataCacheEntry { // True if 'this' is speculatively loaded. This is reset on first hit. Allows // catching a situation where prefetched entries get evicted before they are // hit. - bool isPrefetch_{false}; + tsan_atomic isPrefetch_{false}; // Sets after first use of a prefetched entry. Cleared by // getAndClearFirstUseFlag(). Does not require synchronization since used for @@ -496,6 +495,10 @@ struct CacheStats { int64_t largePadding{0}; /// Total number of entries. int32_t numEntries{0}; + /// Total number of tiny entries. + int32_t numTinyEntries{0}; + /// Total number of large entries. + int32_t numLargeEntries{0}; /// Number of entries that do not cache anything. int32_t numEmptyEntries{0}; /// Number of entries pinned for shared access. 
@@ -630,7 +633,7 @@ class CacheShard { static constexpr uint32_t kMaxFreeEntries = 1 << 10; static constexpr int32_t kNoThreshold = std::numeric_limits::max(); - void calibrateThreshold(); + void calibrateThresholdLocked(); void removeEntryLocked(AsyncDataCacheEntry* entry); @@ -638,7 +641,7 @@ class CacheShard { // // TODO: consider to pass a size hint so as to select the a free entry which // already has the right amount of memory associated with it. - std::unique_ptr getFreeEntry(); + std::unique_ptr getFreeEntryLocked(); CachePin initEntry(RawFileCacheKey key, AsyncDataCacheEntry* entry); @@ -876,9 +879,12 @@ class AsyncDataCache : public memory::Cache { void clear(); private: - static constexpr int32_t kNumShards = 4; // Must be power of 2. + // Must be power of 2. + static constexpr int32_t kNumShards = 4; static constexpr int32_t kShardMask = kNumShards - 1; + static_assert((kNumShards & kShardMask) == 0); + // True if 'acquired' has more pages than 'numPages' or allocator has space // for numPages - acquired pages of more allocation. bool canTryAllocate( diff --git a/velox/common/caching/ScanTracker.cpp b/velox/common/caching/ScanTracker.cpp index aef9fbcd593b..14fc6089f7ce 100644 --- a/velox/common/caching/ScanTracker.cpp +++ b/velox/common/caching/ScanTracker.cpp @@ -21,6 +21,20 @@ namespace facebook::velox::cache { +namespace { +template +void update(folly::ConcurrentHashMap& map, const K& key, U update) { + while (true) { + auto prev = map[key]; + auto updated = prev; + update(updated); + if (map.assign_if_equal(key, prev, updated)) { + break; + } + } +} +} // namespace + // Marks that 'bytes' worth of data may be accessed in the future. See // TrackingData for meaning of quantum. 
void ScanTracker::recordReference( @@ -31,11 +45,10 @@ void ScanTracker::recordReference( if (fileGroupStats_) { fileGroupStats_->recordReference(fileId, groupId, id, bytes); } - std::lock_guard l(mutex_); - auto& data = data_[id]; - data.referencedBytes += bytes; - data.lastReferencedBytes = bytes; - sum_.referencedBytes += bytes; + update(data_, id, [&](auto& value) { + value.referencedBytes += bytes; + value.lastReferencedBytes = bytes; + }); } void ScanTracker::recordRead( @@ -46,10 +59,7 @@ void ScanTracker::recordRead( if (fileGroupStats_) { fileGroupStats_->recordRead(fileId, groupId, id, bytes); } - std::lock_guard l(mutex_); - auto& data = data_[id]; - data.readBytes += bytes; - sum_.readBytes += bytes; + update(data_, id, [&](auto& value) { value.readBytes += bytes; }); } std::string ScanTracker::toString() const { diff --git a/velox/common/caching/ScanTracker.h b/velox/common/caching/ScanTracker.h index e4b8744c17ff..da62b6cd55e2 100644 --- a/velox/common/caching/ScanTracker.h +++ b/velox/common/caching/ScanTracker.h @@ -16,12 +16,8 @@ #pragma once -#include +#include #include -#include - -#include "velox/common/base/BitUtil.h" -#include "velox/common/base/Exceptions.h" namespace facebook::velox::cache { @@ -76,6 +72,12 @@ struct TrackingData { double referencedBytes{}; double lastReferencedBytes{}; double readBytes{}; + + bool operator==(const TrackingData& other) const { + return referencedBytes == other.referencedBytes && + lastReferencedBytes == other.lastReferencedBytes && + readBytes == other.readBytes; + } }; /// Tracks column access frequency during execution of a query. A ScanTracker is @@ -132,8 +134,7 @@ class ScanTracker { /// Returns the percentage of referenced columns that are actually read. 100% /// if no data. 
int32_t readPct(TrackingId id) { - std::lock_guard l(mutex_); - const auto& data = data_[id]; + const auto data = data_[id]; if (data.referencedBytes == 0) { return 100; } @@ -141,7 +142,6 @@ class ScanTracker { } TrackingData trackingData(TrackingId id) { - std::lock_guard l(mutex_); return data_[id]; } @@ -161,9 +161,7 @@ class ScanTracker { const std::function unregisterer_{nullptr}; FileGroupStats* const fileGroupStats_; - std::mutex mutex_; - folly::F14FastMap data_; - TrackingData sum_; + folly::ConcurrentHashMap data_; }; } // namespace facebook::velox::cache diff --git a/velox/common/caching/SsdCache.cpp b/velox/common/caching/SsdCache.cpp index 2347b45982fd..35a274fac8ea 100644 --- a/velox/common/caching/SsdCache.cpp +++ b/velox/common/caching/SsdCache.cpp @@ -90,7 +90,8 @@ bool SsdCache::startWrite() { } void SsdCache::write(std::vector pins) { - VELOX_CHECK_EQ(numShards_, writesInProgress_); + VELOX_CHECK_EQ( + numShards_, writesInProgress_, "startWrite() have not been called"); TestValue::adjust("facebook::velox::cache::SsdCache::write", this); @@ -98,7 +99,7 @@ void SsdCache::write(std::vector pins) { uint64_t bytes = 0; std::vector> shards(numShards_); - for (auto& pin : pins) { + for (const auto& pin : pins) { bytes += pin.checkedEntry()->size(); const auto& target = file(pin.checkedEntry()->key().fileNum.id()); shards[target.shardId()].push_back(std::move(pin)); @@ -135,9 +136,10 @@ void SsdCache::write(std::vector pins) { // Typically occurs every few GB. Allows detecting unusually slow rates // from failing devices. 
VELOX_SSD_CACHE_LOG(INFO) << fmt::format( - "Wrote {}, {} bytes/s", + "Wrote {} to SSD, {} bytes/s", succinctBytes(bytes), - static_cast(bytes) / (getCurrentTimeMicro() - startTimeUs)); + static_cast(bytes) * 1'000'000 / + (getCurrentTimeMicro() - startTimeUs)); } }); } diff --git a/velox/common/caching/SsdCache.h b/velox/common/caching/SsdCache.h index e46705949481..8e6630d9f86d 100644 --- a/velox/common/caching/SsdCache.h +++ b/velox/common/caching/SsdCache.h @@ -49,7 +49,7 @@ class SsdCache { disableFileCow(_disableFileCow), checksumEnabled(_checksumEnabled), checksumReadVerificationEnabled(_checksumReadVerificationEnabled), - executor(_executor){}; + executor(_executor) {} std::string filePrefix; uint64_t maxBytes; diff --git a/velox/common/caching/SsdFile.h b/velox/common/caching/SsdFile.h index c9ab1f5b3016..bed801923208 100644 --- a/velox/common/caching/SsdFile.h +++ b/velox/common/caching/SsdFile.h @@ -32,16 +32,18 @@ class SsdFileTestHelper; class SsdCacheTestHelper; } // namespace test -/// A 64 bit word describing a SSD cache entry in an SsdFile. The low 23 bits -/// are the size, for a maximum entry size of 8MB. The high bits are the offset. +/// The 'fileBits_' field is a 64 bit word describing a SSD cache entry in an +/// SsdFile. The low 23 bits are the size, for a maximum entry size of 8MB. The +/// high 41 bits are the offset. The 'checksum_' field is optional and is used +/// only when the checksum feature is enabled, otherwise, its value is always 0. 
class SsdRun { public: static constexpr int32_t kSizeBits = 23; - SsdRun() : fileBits_(0) {} + SsdRun() = default; SsdRun(uint64_t offset, uint32_t size, uint32_t checksum) - : fileBits_((offset << kSizeBits) | ((size - 1))), checksum_(checksum) { + : fileBits_((offset << kSizeBits) | (size - 1)), checksum_(checksum) { VELOX_CHECK_LT(offset, 1L << (64 - kSizeBits)); VELOX_CHECK_NE(size, 0); VELOX_CHECK_LE(size, 1 << kSizeBits); @@ -58,9 +60,11 @@ class SsdRun { checksum_ = other.checksum_; } - void operator=(SsdRun&& other) { + void operator=(SsdRun&& other) noexcept { fileBits_ = other.fileBits_; checksum_ = other.checksum_; + other.fileBits_ = 0; + other.checksum_ = 0; } uint64_t offset() const { @@ -83,8 +87,8 @@ class SsdRun { private: // Contains the file offset and size. - uint64_t fileBits_; - uint32_t checksum_; + uint64_t fileBits_{0}; + uint32_t checksum_{0}; }; /// Represents an SsdFile entry that is planned for load or being loaded. This @@ -266,7 +270,7 @@ class SsdFile { checksumEnabled(_checksumEnabled), checksumReadVerificationEnabled( _checksumEnabled && _checksumReadVerificationEnabled), - executor(_executor){}; + executor(_executor) {} /// Name of cache file, used as prefix for checkpoint files. 
const std::string fileName; diff --git a/velox/common/caching/StringIdMap.cpp b/velox/common/caching/StringIdMap.cpp index c8c88542da1e..e991c5a750ea 100644 --- a/velox/common/caching/StringIdMap.cpp +++ b/velox/common/caching/StringIdMap.cpp @@ -31,8 +31,8 @@ void StringIdMap::release(uint64_t id) { std::lock_guard l(mutex_); auto it = idToEntry_.find(id); if (it != idToEntry_.end()) { - VELOX_CHECK_LT( - 0, it->second.numInUse, "Extra release of id in StringIdMap"); + VELOX_CHECK_GT( + it->second.numInUse, 0, "Extra release of id in StringIdMap"); if (--it->second.numInUse == 0) { pinnedSize_ -= it->second.string.size(); auto strIter = stringToId_.find(it->second.string); @@ -60,11 +60,11 @@ uint64_t StringIdMap::makeId(std::string_view string) { if (it != stringToId_.end()) { auto entry = idToEntry_.find(it->second); VELOX_CHECK(entry != idToEntry_.end()); - if (++entry->second.numInUse == 1) { - pinnedSize_ += entry->second.string.size(); - } + VELOX_CHECK_GE(entry->second.numInUse, 1); + ++entry->second.numInUse; return it->second; } + Entry entry; entry.string = string; // Check that we do not use an id twice. 
In practice this never @@ -91,9 +91,8 @@ uint64_t StringIdMap::recoverId(uint64_t id, std::string_view string) { id, it->second, "Multiple recover ids assigned to {}", string); auto entry = idToEntry_.find(it->second); VELOX_CHECK(entry != idToEntry_.end()); - if (++entry->second.numInUse == 1) { - pinnedSize_ += entry->second.string.size(); - } + VELOX_CHECK_GE(entry->second.numInUse, 1); + ++entry->second.numInUse; return id; } diff --git a/velox/common/caching/tests/StringIdMapTest.cpp b/velox/common/caching/tests/StringIdMapTest.cpp index 59d95af2e88a..1a1d006748aa 100644 --- a/velox/common/caching/tests/StringIdMapTest.cpp +++ b/velox/common/caching/tests/StringIdMapTest.cpp @@ -22,7 +22,7 @@ using namespace facebook::velox; TEST(StringIdMapTest, basic) { - constexpr const char* kFile1 = "file_1"; + constexpr std::string_view kFile1 = "file_1"; StringIdMap map; uint64_t id = 0; { @@ -33,7 +33,7 @@ TEST(StringIdMapTest, basic) { id = lease2.id(); lease1 = lease2; EXPECT_EQ(id, lease1.id()); - EXPECT_EQ(strlen(kFile1), map.pinnedSize()); + EXPECT_EQ(kFile1.size(), map.pinnedSize()); } StringIdLease lease3(map, kFile1); EXPECT_NE(lease3.id(), id); @@ -56,50 +56,48 @@ TEST(StringIdMapTest, rehash) { } TEST(StringIdMapTest, recover) { - constexpr const char* kRecoverFile1 = "file_1"; - constexpr const char* kRecoverFile2 = "file_2"; - constexpr const char* kRecoverFile3 = "file_3"; + constexpr std::string_view kRecoverFile1("file_1"); + constexpr std::string_view kRecoverFile2("file_2"); + constexpr std::string_view kRecoverFile3("file_3"); StringIdMap map; const uint64_t recoverId1{10}; const uint64_t recoverId2{20}; { StringIdLease lease(map, recoverId1, kRecoverFile1); ASSERT_TRUE(lease.hasValue()); - ASSERT_EQ(map.pinnedSize(), ::strlen(kRecoverFile1)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size()); ASSERT_EQ(map.testingLastId(), recoverId1); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId1, kRecoverFile2), + StringIdLease(map, recoverId1, 
kRecoverFile2), "(1 vs. 0) Reused recover id 10 assigned to file_2"); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId2, kRecoverFile1), + StringIdLease(map, recoverId2, kRecoverFile1), "(20 vs. 10) Multiple recover ids assigned to file_1"); } ASSERT_EQ(map.pinnedSize(), 0); StringIdLease lease1(map, kRecoverFile1); - ASSERT_EQ(map.pinnedSize(), ::strlen(kRecoverFile1)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size()); ASSERT_EQ(map.testingLastId(), recoverId1 + 1); { StringIdLease lease(map, recoverId2, kRecoverFile2); ASSERT_TRUE(lease.hasValue()); ASSERT_EQ(lease.id(), recoverId2); - ASSERT_EQ( - map.pinnedSize(), ::strlen(kRecoverFile1) + ::strlen(kRecoverFile2)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size() + kRecoverFile2.size()); ASSERT_EQ(map.testingLastId(), recoverId2); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId2, kRecoverFile3), + StringIdLease(map, recoverId2, kRecoverFile3), "(1 vs. 0) Reused recover id 20 assigned to file_3"); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId2, kRecoverFile1), + StringIdLease(map, recoverId2, kRecoverFile1), "(20 vs. 
11) Multiple recover ids assigned to file_1"); StringIdLease dupLease(map, recoverId2, kRecoverFile2); ASSERT_TRUE(lease.hasValue()); ASSERT_EQ(lease.id(), recoverId2); - ASSERT_EQ( - map.pinnedSize(), ::strlen(kRecoverFile1) + ::strlen(kRecoverFile2)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size() + kRecoverFile2.size()); } ASSERT_EQ(map.testingLastId(), recoverId2); - ASSERT_EQ(map.pinnedSize(), ::strlen(kRecoverFile1)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size()); } diff --git a/velox/common/config/Config.cpp b/velox/common/config/Config.cpp index 9a37fa65a08a..987dbdb3dc05 100644 --- a/velox/common/config/Config.cpp +++ b/velox/common/config/Config.cpp @@ -138,13 +138,12 @@ std::unordered_map ConfigBase::rawConfigsCopy() return configs_; } -std::optional ConfigBase::get(const std::string& key) const { - std::optional val; - std::shared_lock l(mutex_); - auto it = configs_.find(key); - if (it != configs_.end()) { - val = it->second; +std::optional ConfigBase::access(const std::string& key) const { + std::shared_lock l{mutex_}; + if (auto it = configs_.find(key); it != configs_.end()) { + return it->second; } - return val; + return std::nullopt; } + } // namespace facebook::velox::config diff --git a/velox/common/config/Config.h b/velox/common/config/Config.h index 7aea6575a03a..16031818d9b2 100644 --- a/velox/common/config/Config.h +++ b/velox/common/config/Config.h @@ -23,6 +23,7 @@ #include "folly/Conv.h" #include "velox/common/base/Exceptions.h" +#include "velox/common/config/IConfig.h" namespace facebook::velox::config { @@ -47,7 +48,7 @@ std::chrono::duration toDuration(const std::string& str); /// The concrete config class should inherit the config base and define all the /// entries. 
-class ConfigBase { +class ConfigBase : public IConfig { public: template struct Entry { @@ -111,49 +112,20 @@ class ConfigBase { : entry.defaultVal; } - template - std::optional get( - const std::string& key, - std::function toT = [](auto /* unused */, - auto value) { - return folly::to(value); - }) const { - auto val = get(key); - if (val.has_value()) { - return toT(key, val.value()); - } else { - return std::nullopt; - } - } - - template - T get( - const std::string& key, - const T& defaultValue, - std::function toT = [](auto /* unused */, - auto value) { - return folly::to(value); - }) const { - auto val = get(key); - if (val.has_value()) { - return toT(key, val.value()); - } else { - return defaultValue; - } - } + using IConfig::get; bool valueExists(const std::string& key) const; const std::unordered_map& rawConfigs() const; - std::unordered_map rawConfigsCopy() const; + std::unordered_map rawConfigsCopy() const final; protected: mutable std::shared_mutex mutex_; std::unordered_map configs_; private: - std::optional get(const std::string& key) const; + std::optional access(const std::string& key) const final; const bool mutable_; }; diff --git a/velox/common/config/IConfig.h b/velox/common/config/IConfig.h new file mode 100644 index 000000000000..06d8d8bc3f5e --- /dev/null +++ b/velox/common/config/IConfig.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace facebook::velox::config { + +/// IConfig - Read-only config interface +/// for accessing key-value parameters. +/// Supports value retrieval by key and +/// duplication of the raw configuration data. +/// Can be used by velox::QueryConfig to access +/// externally managed system configuration. +class IConfig { + public: + template + std::optional get( + const std::string& key, + const std::function& toT = + [](auto /* unused */, auto value) { + return folly::to(value); + }) const { + if (auto val = access(key)) { + return toT(key, *val); + } + return std::nullopt; + } + + template + T get( + const std::string& key, + const T& defaultValue, + const std::function& toT = + [](auto /* unused */, auto value) { + return folly::to(value); + }) const { + if (auto val = access(key)) { + return toT(key, *val); + } + return defaultValue; + } + + virtual std::unordered_map rawConfigsCopy() + const = 0; + + virtual ~IConfig() = default; + + private: + virtual std::optional access(const std::string& key) const = 0; +}; + +} // namespace facebook::velox::config diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index aa521a57d47f..11f7c1c0ed2d 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -187,14 +187,13 @@ size_t Base64::calculateEncodedSize(size_t inputSize, bool withPadding) { // static void Base64::encode(const char* input, size_t inputSize, char* output) { - encodeImpl( - folly::StringPiece(input, inputSize), kBase64Charset, true, output); + encodeImpl(std::string_view(input, inputSize), kBase64Charset, true, output); } // static void Base64::encodeUrl(const char* input, size_t inputSize, char* output) { encodeImpl( - folly::StringPiece(input, inputSize), kBase64UrlCharset, true, output); + std::string_view(input, inputSize), kBase64UrlCharset, true, output); } // static @@ -249,13 +248,13 @@ void Base64::encodeImpl( } // static 
-std::string Base64::encode(folly::StringPiece text) { +std::string Base64::encode(std::string_view text) { return encodeImpl(text, kBase64Charset, true); } // static std::string Base64::encode(const char* input, size_t inputSize) { - return encode(folly::StringPiece(input, inputSize)); + return encode(std::string_view(input, inputSize)); } namespace { @@ -308,7 +307,7 @@ std::string Base64::encode(const folly::IOBuf* inputBuffer) { } // static -std::string Base64::decode(folly::StringPiece encodedText) { +std::string Base64::decode(std::string_view encodedText) { std::string decodedResult; decode(std::make_pair(encodedText.data(), encodedText.size()), decodedResult); return decodedResult; @@ -492,13 +491,13 @@ Expected Base64::decodeImpl( } // static -std::string Base64::encodeUrl(folly::StringPiece text) { +std::string Base64::encodeUrl(std::string_view text) { return encodeImpl(text, kBase64UrlCharset, false); } // static std::string Base64::encodeUrl(const char* input, size_t inputSize) { - return encodeUrl(folly::StringPiece(input, inputSize)); + return encodeUrl(std::string_view(input, inputSize)); } // static @@ -521,7 +520,7 @@ Status Base64::decodeUrl( } // static -std::string Base64::decodeUrl(folly::StringPiece encodedText) { +std::string Base64::decodeUrl(std::string_view encodedText) { std::string decodedOutput; decodeUrl( std::make_pair(encodedText.data(), encodedText.size()), decodedOutput); diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 073cc49cd4f3..7dca7d2fdbce 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -16,7 +16,6 @@ #pragma once -#include #include #include @@ -45,7 +44,7 @@ class Base64 { static std::string encode(const char* input, size_t inputSize); /// Encodes the specified text. - static std::string encode(folly::StringPiece text); + static std::string encode(std::string_view text); /// Encodes the specified IOBuf data. 
static std::string encode(const folly::IOBuf* inputBuffer); @@ -60,7 +59,7 @@ class Base64 { static std::string encodeUrl(const char* input, size_t inputSize); /// Encodes the specified text using URL encoding. - static std::string encodeUrl(folly::StringPiece text); + static std::string encodeUrl(std::string_view text); /// Encodes the specified IOBuf data using URL encoding. static std::string encodeUrl(const folly::IOBuf* inputBuffer); @@ -72,7 +71,7 @@ class Base64 { encodeUrl(const char* input, size_t inputSize, char* outputBuffer); /// Decodes the input Base64 encoded string. - static std::string decode(folly::StringPiece encodedText); + static std::string decode(std::string_view encodedText); /// Decodes the specified encoded payload and writes the result to the /// 'output'. @@ -94,7 +93,7 @@ class Base64 { size_t outputSize); /// Decodes the input Base64 URL encoded string. - static std::string decodeUrl(folly::StringPiece encodedText); + static std::string decodeUrl(std::string_view encodedText); /// Decodes the specified URL encoded payload and writes the result to the /// 'output'. diff --git a/velox/common/encode/ByteStream.h b/velox/common/encode/ByteStream.h index 70bd25025138..9ebbee1054d7 100644 --- a/velox/common/encode/ByteStream.h +++ b/velox/common/encode/ByteStream.h @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #pragma once /** @@ -26,7 +27,6 @@ #include #include -#include namespace facebook::velox::strings { @@ -86,28 +86,28 @@ class ByteSink { * append() will return 0). In particular, this may not be used for * non-blocking behavior. */ - virtual size_t append(folly::StringPiece str) = 0; + virtual size_t append(std::string_view str) = 0; size_t append(const void* data, size_t size) { - return append(folly::StringPiece(static_cast(data), size)); + return append(std::string_view(static_cast(data), size)); } /** - * Append the given string to this ByteSink. 
The string must remain + * Append the given string to this ByteSink. The string must remain * allocated (and unchanged) until the ByteSink is destroyed. */ - virtual size_t appendAllocated(folly::StringPiece str) { + virtual size_t appendAllocated(std::string_view str) { return append(str); } /** * Convenience function that appends the bitwise representation of count - * objects starting at address obj. The usual caveats about endianness, + * objects starting at address obj. The usual caveats about endianness, * padding apply. */ template size_t appendBitwise(const T* obj, size_t count) { const size_t sz = count * sizeof(T); - return append(folly::StringPiece(reinterpret_cast(obj), sz)); + return append(std::string_view(reinterpret_cast(obj), sz)); } /** @@ -185,8 +185,8 @@ class SByteSink : public ByteSink { public: explicit SByteSink(S* str) : str_(str) {} - size_t append(folly::StringPiece s) override { - str_->append(s.start(), s.size()); + size_t append(std::string_view s) override { + str_->append(s.data(), s.size()); return s.size(); } @@ -237,7 +237,7 @@ class ByteSource { * next() will return false, but bad() will also return false. On error, * next() returns false, and bad() returns true. 
*/ - virtual bool next(folly::StringPiece* chunk) = 0; + virtual bool next(std::string_view* chunk) = 0; /** * Push back the last numBytes returned by the last next() call, so @@ -316,7 +316,7 @@ class ByteSourceBuffer : public std::basic_streambuf { class StringByteSource : public ByteSource { public: explicit StringByteSource( - const folly::StringPiece& str, + const std::string_view& str, size_t maxBytes = kSizeMax) : str_(str), offset_(0), @@ -326,15 +326,17 @@ class StringByteSource : public ByteSource { bool bad() const override { return false; } - bool next(folly::StringPiece* chunk) override { + + bool next(std::string_view* chunk) override { if (offset_ == str_.size()) { return false; } size_t len = std::min(str_.size() - offset_, maxBytes_); - chunk->reset(str_.start() + offset_, len); + *chunk = std::string_view(str_.data() + offset_, len); offset_ += len; return true; } + void backUp(size_t numBytes) override { CHECK_LE(numBytes, maxBytes_); CHECK_GE(offset_, numBytes); @@ -342,7 +344,7 @@ class StringByteSource : public ByteSource { } private: - folly::StringPiece str_; + std::string_view str_; size_t offset_; size_t maxBytes_; }; diff --git a/velox/common/encode/Coding.h b/velox/common/encode/Coding.h index 55e66d3ce915..e4edaac4bcde 100644 --- a/velox/common/encode/Coding.h +++ b/velox/common/encode/Coding.h @@ -109,20 +109,20 @@ class Varint { char buf[kMaxSize64]; char* p = buf; encode(val, &p); - sink->append(folly::StringPiece(buf, p - buf)); + sink->append(std::string_view(buf, p - buf)); } static void encode128ToByteSink(UInt128 val, strings::ByteSink* sink) { char buf[kMaxSize128]; char* p = buf; encode128(val, &p); - sink->append(folly::StringPiece(buf, p - buf)); + sink->append(std::string_view(buf, p - buf)); } // Returns true if decode can be called without causing a CHECK failure. 
// The pointers are not adjusted at all - static bool canDecode(folly::StringPiece src) { - src = src.subpiece(0, kMaxSize64); + static bool canDecode(std::string_view src) { + src = src.substr(0, kMaxSize64); return std::any_of( src.begin(), src.end(), [](char v) { return ~v & 0x80; }); } @@ -187,18 +187,18 @@ class Varint { return val; } - // Decode a value from a StringPiece, and advance the StringPiece. - static uint64_t decode(folly::StringPiece* data) { - const char* p = data->start(); + // Decode a value from a string_view, and advance it. + static uint64_t decode(std::string_view* data) { + const char* p = data->data(); uint64_t val = decode(&p, data->size()); - data->advance(p - data->start()); + data->remove_prefix(p - data->data()); return val; } - static UInt128 decode128(folly::StringPiece* data) { - const char* p = data->start(); + static UInt128 decode128(std::string_view* data) { + const char* p = data->data(); UInt128 val = decode128(&p, data->size()); - data->advance(p - data->start()); + data->remove_prefix(p - data->data()); return val; } @@ -207,13 +207,13 @@ class Varint { uint64_t val = 0; int32_t shift = 0; int32_t max_size = kMaxSize64; - folly::StringPiece chunk; + std::string_view chunk; int32_t remaining = 0; const char* p = nullptr; for (;;) { if (remaining == 0) { CHECK(src->next(&chunk)); - p = chunk.start(); + p = chunk.data(); remaining = chunk.size(); DCHECK_GT(remaining, 0); } @@ -238,13 +238,13 @@ class Varint { UInt128 val = 0; int32_t shift = 0; int32_t max_size = kMaxSize128; - folly::StringPiece chunk; + std::string_view chunk; int32_t remaining = 0; const char* p = nullptr; for (;;) { if (remaining == 0) { CHECK(src->next(&chunk)); - p = chunk.start(); + p = chunk.data(); remaining = chunk.size(); DCHECK_GT(remaining, 0); } @@ -292,7 +292,7 @@ namespace detail { class ByteSinkAppender { public: /* implicit */ ByteSinkAppender(strings::ByteSink* out) : out_(out) {} - void operator()(folly::StringPiece sp) { + void 
operator()(std::string_view sp) { out_->append(sp.data(), sp.size()); } diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 91b0fcab9087..bc78bd6c5019 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -25,27 +25,20 @@ namespace facebook::velox::encoding { class Base64Test : public ::testing::Test {}; TEST_F(Base64Test, fromBase64) { - EXPECT_EQ( - "Hello, World!", - Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ=="))); + EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ==")); EXPECT_EQ( "Base64 encoding is fun.", - Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4="))); - EXPECT_EQ( - "Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ="))); - EXPECT_EQ( - "1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA=="))); + Base64::decode("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=")); + EXPECT_EQ("Simple text", Base64::decode("U2ltcGxlIHRleHQ=")); + EXPECT_EQ("1234567890", Base64::decode("MTIzNDU2Nzg5MA==")); // Check encoded strings without padding - EXPECT_EQ( - "Hello, World!", - Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ"))); + EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ")); EXPECT_EQ( "Base64 encoding is fun.", - Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4"))); - EXPECT_EQ( - "Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ"))); - EXPECT_EQ("1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA"))); + Base64::decode("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4")); + EXPECT_EQ("Simple text", Base64::decode("U2ltcGxlIHRleHQ")); + EXPECT_EQ("1234567890", Base64::decode("MTIzNDU2Nzg5MA")); } TEST_F(Base64Test, calculateDecodedSizeProperSize) { diff --git a/velox/common/file/File.cpp b/velox/common/file/File.cpp index c1f1e8982737..668cba32cede 100644 --- a/velox/common/file/File.cpp +++ b/velox/common/file/File.cpp @@ -60,10 +60,10 @@ T getAttribute( 
std::string ReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { std::string buf; buf.resize(length); - auto res = pread(offset, length, buf.data(), stats); + auto res = pread(offset, length, buf.data(), fileStorageContext); buf.resize(res.size()); return buf; } @@ -71,7 +71,7 @@ std::string ReadFile::pread( uint64_t ReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { auto fileSize = size(); uint64_t numRead = 0; if (offset >= fileSize) { @@ -81,7 +81,7 @@ uint64_t ReadFile::preadv( auto copySize = std::min(range.size(), fileSize - offset); // NOTE: skip the gap in case of coalesce io. if (range.data() != nullptr) { - pread(offset, copySize, range.data(), stats); + pread(offset, copySize, range.data(), fileStorageContext); } offset += copySize; numRead += copySize; @@ -92,18 +92,21 @@ uint64_t ReadFile::preadv( uint64_t ReadFile::preadv( folly::Range regions, folly::Range iobufs, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { VELOX_CHECK_EQ(regions.size(), iobufs.size()); uint64_t length = 0; for (size_t i = 0; i < regions.size(); ++i) { const auto& region = regions[i]; auto& output = iobufs[i]; output = folly::IOBuf(folly::IOBuf::CREATE, region.length); - pread(region.offset, region.length, output.writableData(), stats); + pread( + region.offset, + region.length, + output.writableData(), + fileStorageContext); output.append(region.length); length += region.length; } - return length; } @@ -111,7 +114,7 @@ std::string_view InMemoryReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { bytesRead_ += length; memcpy(buf, file_.data() + offset, length); return {static_cast(buf), length}; @@ -120,7 
+123,7 @@ std::string_view InMemoryReadFile::pread( std::string InMemoryReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { bytesRead_ += length; return std::string(file_.data() + offset, length); } @@ -202,7 +205,7 @@ std::string_view LocalReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { preadInternal(offset, length, static_cast(buf)); return {static_cast(buf), length}; } @@ -210,7 +213,7 @@ std::string_view LocalReadFile::pread( uint64_t LocalReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { // Dropped bytes sized so that a typical dropped range of 50K is not // too many iovecs. static thread_local std::vector droppedBytes(16 * 1024); @@ -267,17 +270,18 @@ uint64_t LocalReadFile::preadv( folly::SemiFuture LocalReadFile::preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { if (!executor_) { - return ReadFile::preadvAsync(offset, buffers, stats); + return ReadFile::preadvAsync(offset, buffers, fileStorageContext); } auto [promise, future] = folly::makePromiseContract(); executor_->add([this, _promise = std::move(promise), _offset = offset, _buffers = buffers, - _stats = stats]() mutable { - auto delegateFuture = ReadFile::preadvAsync(_offset, _buffers, _stats); + _fileStorageContext = fileStorageContext]() mutable { + auto delegateFuture = + ReadFile::preadvAsync(_offset, _buffers, _fileStorageContext); _promise.setTry(std::move(delegateFuture).getTry()); }); return std::move(future); diff --git a/velox/common/file/File.h b/velox/common/file/File.h index 18d1c264ca7a..6089915dffac 100644 --- a/velox/common/file/File.h +++ 
b/velox/common/file/File.h @@ -37,14 +37,32 @@ #include #include #include +#include #include "velox/common/base/Exceptions.h" #include "velox/common/file/FileSystems.h" #include "velox/common/file/Region.h" #include "velox/common/io/IoStatistics.h" +#include + namespace facebook::velox { +struct FileStorageContext { + /// Stats for IO operations + filesystems::File::IoStats* ioStats{nullptr}; + + /// Options for file read operations + folly::F14FastMap fileReadOps; + + FileStorageContext() = default; + + FileStorageContext( + filesystems::File::IoStats* stats, + folly::F14FastMap fileReadOps = {}) + : ioStats(stats), fileReadOps(std::move(fileReadOps)) {} +}; + // A read-only file. All methods in this object should be thread safe. class ReadFile { public: @@ -52,16 +70,12 @@ class ReadFile { // Reads the data at [offset, offset + length) into the provided pre-allocated // buffer 'buf'. The bytes are returned as a string_view pointing to 'buf'. - // - // 'stats' is an IoStatistics pointer passed in by the caller to collect stats - // for this read operation. - // // This method should be thread safe. virtual std::string_view pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const = 0; + const FileStorageContext& fileStorageContext = {}) const = 0; // Same as above, but returns owned data directly. // @@ -69,20 +83,16 @@ class ReadFile { virtual std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats = nullptr) const; + const FileStorageContext& fileStorageContext = {}) const; // Reads starting at 'offset' into the memory referenced by the // Ranges in 'buffers'. The buffers are filled left to right. A // buffer with nullptr data will cause its size worth of bytes to be skipped. - // - // 'stats' is an IoStatistics pointer passed in by the caller to collect stats - // for this read operation. - // // This method should be thread safe. 
virtual uint64_t preadv( uint64_t /*offset*/, const std::vector>& /*buffers*/, - filesystems::File::IoStats* stats = nullptr) const; + const FileStorageContext& fileStorageContext = {}) const; // Vectorized read API. Implementations can coalesce and parallelize. // The offsets don't need to be sorted. @@ -93,30 +103,23 @@ class ReadFile { // by the preadv. // Returns the total number of bytes read, which might be different than the // sum of all buffer sizes (for example, if coalescing was used). - // - // 'stats' is an IoStatistics pointer passed in by the caller to collect stats - // for this read operation. - // // This method should be thread safe. virtual uint64_t preadv( folly::Range regions, folly::Range iobufs, - filesystems::File::IoStats* stats = nullptr) const; + const FileStorageContext& fileStorageContext = {}) const; /// Like preadv but may execute asynchronously and returns the read size or /// exception via SemiFuture. Use hasPreadvAsync() to check if the /// implementation is in fact asynchronous. - /// - /// 'stats' is an IoStatistics pointer passed in by the caller to collect - /// stats for this read operation. - /// /// This method should be thread safe. 
virtual folly::SemiFuture preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const { + const FileStorageContext& fileStorageContext = {}) const { try { - return folly::SemiFuture(preadv(offset, buffers, stats)); + return folly::SemiFuture( + preadv(offset, buffers, fileStorageContext)); } catch (const std::exception& e) { return folly::makeSemiFuture(e); } @@ -240,12 +243,12 @@ class InMemoryReadFile : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const override; + const FileStorageContext& fileStorageContext = {}) const override; uint64_t size() const final { return file_.size(); @@ -311,19 +314,19 @@ class LocalReadFile final : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t size() const final; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; folly::SemiFuture preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; bool hasPreadvAsync() const override { return executor_ != nullptr; diff --git a/velox/common/file/FileSystems.h b/velox/common/file/FileSystems.h index 337478be63bf..f4865c0ce6bc 100644 --- a/velox/common/file/FileSystems.h +++ b/velox/common/file/FileSystems.h @@ -88,6 +88,10 @@ struct FileOptions { /// A token provider that can be used to get tokens for accessing the file. 
std::shared_ptr tokenProvider{nullptr}; + + /// File read operations metadata that can be passed to the underlying file + /// system for tracking and logging purposes. + folly::F14FastMap fileReadOps{}; }; /// Defines directory options diff --git a/velox/common/file/tests/FaultyFile.cpp b/velox/common/file/tests/FaultyFile.cpp index 17897fa99214..434028a6750d 100644 --- a/velox/common/file/tests/FaultyFile.cpp +++ b/velox/common/file/tests/FaultyFile.cpp @@ -34,7 +34,7 @@ std::string_view FaultyReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { if (injectionHook_ != nullptr) { FaultFileReadOperation op(path_, offset, length, buf); injectionHook_(&op); @@ -42,13 +42,13 @@ std::string_view FaultyReadFile::pread( return std::string_view(static_cast(op.buf), op.length); } } - return delegatedFile_->pread(offset, length, buf, stats); + return delegatedFile_->pread(offset, length, buf, fileStorageContext); } uint64_t FaultyReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { if (injectionHook_ != nullptr) { FaultFileReadvOperation op(path_, offset, buffers); injectionHook_(&op); @@ -56,16 +56,16 @@ uint64_t FaultyReadFile::preadv( return op.readBytes; } } - return delegatedFile_->preadv(offset, buffers, stats); + return delegatedFile_->preadv(offset, buffers, fileStorageContext); } folly::SemiFuture FaultyReadFile::preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { // TODO: add fault injection for async read later. 
if (delegatedFile_->hasPreadvAsync() || executor_ == nullptr) { - return delegatedFile_->preadvAsync(offset, buffers, stats); + return delegatedFile_->preadvAsync(offset, buffers, fileStorageContext); } auto promise = std::make_unique>(); folly::SemiFuture future = promise->getSemiFuture(); @@ -73,9 +73,9 @@ folly::SemiFuture FaultyReadFile::preadvAsync( _promise = std::move(promise), _offset = offset, _buffers = buffers, - _stats = stats]() { + _fileStorageContext = fileStorageContext]() { auto delegateFuture = - delegatedFile_->preadvAsync(_offset, _buffers, _stats); + delegatedFile_->preadvAsync(_offset, _buffers, _fileStorageContext); _promise->setValue(delegateFuture.wait().value()); }); return future; diff --git a/velox/common/file/tests/FaultyFile.h b/velox/common/file/tests/FaultyFile.h index 2b4818bd7a16..8fac903b0f6e 100644 --- a/velox/common/file/tests/FaultyFile.h +++ b/velox/common/file/tests/FaultyFile.h @@ -29,7 +29,7 @@ class FaultyReadFile : public ReadFile { FileFaultInjectionHook injectionHook, folly::Executor* executor); - ~FaultyReadFile() override{}; + ~FaultyReadFile() override {} uint64_t size() const override { return delegatedFile_->size(); @@ -39,12 +39,12 @@ class FaultyReadFile : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; uint64_t memoryUsage() const override { return delegatedFile_->memoryUsage(); @@ -72,7 +72,7 @@ class FaultyReadFile : public ReadFile { folly::SemiFuture preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; private: const std::string path_; 
@@ -88,7 +88,7 @@ class FaultyWriteFile : public WriteFile { std::shared_ptr delegatedFile, FileFaultInjectionHook injectionHook); - ~FaultyWriteFile() override{}; + ~FaultyWriteFile() override {} void append(std::string_view data) override; diff --git a/velox/common/future/CMakeLists.txt b/velox/common/future/CMakeLists.txt new file mode 100644 index 000000000000..a598690b32e5 --- /dev/null +++ b/velox/common/future/CMakeLists.txt @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +velox_install_library_headers() diff --git a/velox/common/future/VeloxPromise.h b/velox/common/future/VeloxPromise.h index ba8fbdfbce3b..4aa39488aa01 100644 --- a/velox/common/future/VeloxPromise.h +++ b/velox/common/future/VeloxPromise.h @@ -65,6 +65,12 @@ using ContinuePromise = VeloxPromise; using ContinueFuture = folly::SemiFuture; /// Equivalent of folly's makePromiseContract for VeloxPromise. +/// +/// NOTE: When you already have a valid promise, just call +/// Promise::getSemiFuture() on it to get the future, instead of using this +/// function to overwrite the promise. Overwriting valid promise would cause +/// exception throwing and stack unwinding thus performance issue. See +/// https://github.com/prestodb/presto/issues/26094 for details. 
static inline std::pair makeVeloxContinuePromiseContract(const std::string& promiseContext = "") { auto p = ContinuePromise(promiseContext); diff --git a/velox/common/geospatial/CMakeLists.txt b/velox/common/geospatial/CMakeLists.txt new file mode 100644 index 000000000000..a598690b32e5 --- /dev/null +++ b/velox/common/geospatial/CMakeLists.txt @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +velox_install_library_headers() diff --git a/velox/common/geospatial/GeometryConstants.h b/velox/common/geospatial/GeometryConstants.h new file mode 100644 index 000000000000..56462ab2689b --- /dev/null +++ b/velox/common/geospatial/GeometryConstants.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +// This file contains constats for working with geospatial queries. 
+// They _must not_ require the GEOS library (or any 3p library). + +namespace facebook::velox::common::geospatial { + +enum class GeometrySerializationType : uint8_t { + POINT = 0, + MULTI_POINT = 1, + LINE_STRING = 2, + MULTI_LINE_STRING = 3, + POLYGON = 4, + MULTI_POLYGON = 5, + GEOMETRY_COLLECTION = 6, + ENVELOPE = 7 +}; + +enum class EsriShapeType : uint32_t { + POINT = 1, + POLYLINE = 3, + POLYGON = 5, + MULTI_POINT = 8 +}; + +} // namespace facebook::velox::common::geospatial diff --git a/velox/common/hyperloglog/DenseHll.cpp b/velox/common/hyperloglog/DenseHll.cpp index f3a3f54b2d2d..2245a3f69096 100644 --- a/velox/common/hyperloglog/DenseHll.cpp +++ b/velox/common/hyperloglog/DenseHll.cpp @@ -15,13 +15,13 @@ */ #include "velox/common/hyperloglog/DenseHll.h" -#include -#include +#include "velox/common/base/BitUtil.h" #include "velox/common/base/IOUtils.h" #include "velox/common/hyperloglog/BiasCorrection.h" #include "velox/common/hyperloglog/HllUtils.h" namespace facebook::velox::common::hll { + namespace { const int kBitsPerBucket = 4; const int8_t kMaxDelta = (1 << kBitsPerBucket) - 1; @@ -119,14 +119,26 @@ double correctBias(double rawEstimate, int8_t indexBitLength) { } } // namespace -DenseHll::DenseHll(int8_t indexBitLength, HashStringAllocator* allocator) - : deltas_{StlAllocator(allocator)}, - overflowBuckets_{StlAllocator(allocator)}, - overflowValues_{StlAllocator(allocator)} { +template +DenseHll::DenseHll(int8_t indexBitLength, TAllocator* allocator) + : allocator_(allocator), + deltas_{TStlAllocator(allocator)}, + overflowBuckets_{TStlAllocator(allocator)}, + overflowValues_{TStlAllocator(allocator)} { initialize(indexBitLength); } -void DenseHll::initialize(int8_t indexBitLength) { +template +DenseHll::DenseHll(TAllocator* allocator) + : indexBitLength_(-1), + baselineCount_(0), + allocator_(allocator), + deltas_{TStlAllocator(allocator)}, + overflowBuckets_{TStlAllocator(allocator)}, + overflowValues_{TStlAllocator(allocator)} {} + +template 
+void DenseHll::initialize(int8_t indexBitLength) { VELOX_CHECK_GE(indexBitLength, 4, "indexBitLength must be in [4, 16] range"); VELOX_CHECK_LE(indexBitLength, 16, "indexBitLength must be in [4, 16] range"); @@ -137,13 +149,15 @@ void DenseHll::initialize(int8_t indexBitLength) { deltas_.resize(numBuckets * kBitsPerBucket / 8); } -void DenseHll::insertHash(uint64_t hash) { +template +void DenseHll::insertHash(uint64_t hash) { auto index = computeIndex(hash, indexBitLength_); auto value = numberOfLeadingZeros(hash, indexBitLength_) + 1; insert(index, value); } -void DenseHll::insert(int32_t index, int8_t value) { +template +void DenseHll::insert(int32_t index, int8_t value) { auto delta = value - baseline_; auto oldDelta = getDelta(index); @@ -261,7 +275,8 @@ DenseHllView deserialize(const char* serialized) { } } // namespace -int64_t DenseHll::cardinality() const { +template +int64_t DenseHll::cardinality() const { DenseHllView hll{ indexBitLength_, baseline_, @@ -272,18 +287,14 @@ int64_t DenseHll::cardinality() const { return cardinalityImpl(hll); } -// static -int64_t DenseHll::cardinality(const char* serialized) { - auto hll = deserialize(serialized); - return cardinalityImpl(hll); -} - -int8_t DenseHll::getDelta(int32_t index) const { +template +int8_t DenseHll::getDelta(int32_t index) const { int slot = index >> 1; return (deltas_[slot] >> shiftForBucket(index)) & kBucketMask; } -void DenseHll::setDelta(int32_t index, int8_t value) { +template +void DenseHll::setDelta(int32_t index, int8_t value) { int slot = index >> 1; // Clear the old value. 
@@ -295,12 +306,14 @@ void DenseHll::setDelta(int32_t index, int8_t value) { deltas_[slot] |= setMask; } -int8_t DenseHll::getOverflow(int32_t index) const { +template +int8_t DenseHll::getOverflow(int32_t index) const { return getOverflowImpl( index, overflows_, overflowBuckets_.data(), overflowValues_.data()); } -int DenseHll::findOverflowEntry(int32_t index) const { +template +int DenseHll::findOverflowEntry(int32_t index) const { for (auto i = 0; i < overflows_; i++) { if (overflowBuckets_[i] == index) { return i; @@ -309,7 +322,8 @@ int DenseHll::findOverflowEntry(int32_t index) const { return -1; } -void DenseHll::adjustBaselineIfNeeded() { +template +void DenseHll::adjustBaselineIfNeeded() { auto numBuckets = 1 << indexBitLength_; while (baselineCount_ == 0) { @@ -359,7 +373,8 @@ void DenseHll::adjustBaselineIfNeeded() { } } -void DenseHll::sortOverflows() { +template +void DenseHll::sortOverflows() { // traditional insertion sort (ok for small arrays) for (int i = 1; i < overflows_; i++) { auto bucket = overflowBuckets_[i]; @@ -385,7 +400,8 @@ void DenseHll::sortOverflows() { } } -int32_t DenseHll::serializedSize() const { +template +int32_t DenseHll::serializedSize() const { return 1 /* type + version */ + 1 /* indexBitLength */ + 1 /* baseline */ @@ -395,13 +411,17 @@ int32_t DenseHll::serializedSize() const { + overflows_ /* overflow bucket values */; } -// static -bool DenseHll::canDeserialize(const char* input) { +int64_t DenseHlls::cardinality(const char* serialized) { + auto hll = deserialize(serialized); + return cardinalityImpl(hll); +} + +bool DenseHlls::canDeserialize(const char* input) { return *reinterpret_cast(input) == kPrestoDenseV2; } // static -bool DenseHll::canDeserialize(const char* input, int size) { +bool DenseHlls::canDeserialize(const char* input, int size) { if (size < 5) { // Min serialized sparse HLL size is 5 bytes. 
return false; @@ -459,22 +479,23 @@ bool DenseHll::canDeserialize(const char* input, int size) { return true; } -// static -int8_t DenseHll::deserializeIndexBitLength(const char* input) { +int8_t DenseHlls::deserializeIndexBitLength(const char* input) { common::InputByteStream stream(input); stream.read(); return stream.read(); } -// static -int32_t DenseHll::estimateInMemorySize(int8_t indexBitLength) { +int32_t DenseHlls::estimateInMemorySize(int8_t indexBitLength) { // Note: we don't take into account overflow entries since their number can // vary. - return sizeof(indexBitLength_) + sizeof(baseline_) + sizeof(baselineCount_) + + // return sizeof(indexBitLength_) + sizeof(baseline_) + + // sizeof(baselineCount_) + (1 << indexBitLength) / 2; + return sizeof(int8_t) + sizeof(int8_t) + sizeof(int32_t) + (1 << indexBitLength) / 2; } -void DenseHll::serialize(char* output) { +template +void DenseHll::serialize(char* output) { // sort overflow arrays to get consistent serialization for equivalent HLLs sortOverflows(); @@ -492,10 +513,12 @@ void DenseHll::serialize(char* output) { } } -DenseHll::DenseHll(const char* serialized, HashStringAllocator* allocator) - : deltas_{StlAllocator(allocator)}, - overflowBuckets_{StlAllocator(allocator)}, - overflowValues_{StlAllocator(allocator)} { +template +DenseHll::DenseHll(const char* serialized, TAllocator* allocator) + : allocator_(allocator), + deltas_{TStlAllocator(allocator)}, + overflowBuckets_{TStlAllocator(allocator)}, + overflowValues_{TStlAllocator(allocator)} { auto hll = deserialize(serialized); initialize(hll.indexBitLength); baseline_ = hll.baseline; @@ -525,7 +548,8 @@ DenseHll::DenseHll(const char* serialized, HashStringAllocator* allocator) } } -void DenseHll::mergeWith(const DenseHll& other) { +template +void DenseHll::mergeWith(const DenseHll& other) { VELOX_CHECK_EQ( indexBitLength_, other.indexBitLength_, @@ -539,7 +563,8 @@ void DenseHll::mergeWith(const DenseHll& other) { other.overflowValues_.data()}); } 
-void DenseHll::mergeWith(const char* serialized) { +template +void DenseHll::mergeWith(const char* serialized) { common::InputByteStream stream(serialized); auto version = stream.read(); @@ -561,7 +586,8 @@ void DenseHll::mergeWith(const char* serialized) { mergeWith({baseline, deltas, overflows, overflowBuckets, overflowValues}); } -std::pair DenseHll::computeNewValue( +template +std::pair DenseHll::computeNewValue( int8_t delta, int8_t otherDelta, int32_t bucket, @@ -585,7 +611,8 @@ std::pair DenseHll::computeNewValue( return {std::max(value1, value2), overflowEntry}; } -void DenseHll::mergeWith(const HllView& other) { +template +void DenseHll::mergeWith(const HllView& other) { // Number of 'delta' bytes that fit in a single SIMD batch. Each 'delta' byte // stores 2 4-bit deltas. constexpr auto batchSize = xsimd::batch::size; @@ -611,7 +638,10 @@ void DenseHll::mergeWith(const HllView& other) { adjustBaselineIfNeeded(); } -int32_t DenseHll::mergeWithSimd(const HllView& other, int8_t newBaseline) { +template +int32_t DenseHll::mergeWithSimd( + const HllView& other, + int8_t newBaseline) { const auto batchSize = xsimd::batch::size; const auto bucketMaskBatch = xsimd::broadcast(kBucketMask); @@ -751,7 +781,10 @@ int32_t DenseHll::mergeWithSimd(const HllView& other, int8_t newBaseline) { return baselineCount; } -int32_t DenseHll::mergeWithScalar(const HllView& other, int8_t newBaseline) { +template +int32_t DenseHll::mergeWithScalar( + const HllView& other, + int8_t newBaseline) { int32_t baselineCount = 0; int bucket = 0; @@ -787,8 +820,11 @@ int32_t DenseHll::mergeWithScalar(const HllView& other, int8_t newBaseline) { return baselineCount; } -int8_t -DenseHll::updateOverflow(int32_t index, int overflowEntry, int8_t delta) { +template +int8_t DenseHll::updateOverflow( + int32_t index, + int overflowEntry, + int8_t delta) { if (delta > kMaxDelta) { if (overflowEntry != -1) { // update existing overflow @@ -804,7 +840,8 @@ DenseHll::updateOverflow(int32_t index, int 
overflowEntry, int8_t delta) { return delta; } -void DenseHll::addOverflow(int32_t index, int8_t overflow) { +template +void DenseHll::addOverflow(int32_t index, int8_t overflow) { overflowBuckets_.resize(overflows_ + 1); overflowValues_.resize(overflows_ + 1); @@ -813,10 +850,17 @@ void DenseHll::addOverflow(int32_t index, int8_t overflow) { overflows_++; } -void DenseHll::removeOverflow(int overflowEntry) { +template +void DenseHll::removeOverflow(int overflowEntry) { // Remove existing overflow. overflowBuckets_[overflowEntry] = overflowBuckets_[overflows_ - 1]; overflowValues_[overflowEntry] = overflowValues_[overflows_ - 1]; overflows_--; } + +// Explicit template instantiation for both HashStringAllocator (default) and +// memory::MemoryPool +template class DenseHll; +template class DenseHll; + } // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/DenseHll.h b/velox/common/hyperloglog/DenseHll.h index b6b5f03f8cdf..6936b6d8085b 100644 --- a/velox/common/hyperloglog/DenseHll.h +++ b/velox/common/hyperloglog/DenseHll.h @@ -17,7 +17,45 @@ #include "velox/common/memory/HashStringAllocator.h" namespace facebook::velox::common::hll { -class SparseHll; + +class DenseHlls { + public: + /// Returns cardinality estimate from the specified serialized digest. + /// @param serialized Pointer to serialized DenseHll data + /// @return Estimated cardinality of the HyperLogLog + static int64_t cardinality(const char* serialized); + + /// Returns true if 'input' contains Presto DenseV2 format indicator. 
+ /// @param input Pointer to serialized data to check + /// @return True if the data is in DenseV2 format, false otherwise + static bool canDeserialize(const char* input); + + /// Returns true if 'input' contains Presto DenseV2 format indicator and the + /// rest of the data matches HLL format: + /// 1 byte for version + /// 1 byte for index bit length, index bit length must be in [4,16] + /// 1 byte for baseline value + /// 2^(n-1) bytes for buckets, values in buckets must be in [0,63] + /// 2 bytes for # overflow buckets + /// 3 * #overflow buckets bytes for overflow buckets/values + /// More information here: + /// https://engineering.fb.com/2018/12/13/data-infrastructure/hyperloglog/ + /// @param input Pointer to serialized data to validate + /// @param size Size of the serialized data in bytes + /// @return True if the data is valid DenseV2 format, false otherwise + static bool canDeserialize(const char* input, int size); + + /// Extracts the index bit length from serialized DenseHll data. + /// @param input Pointer to serialized DenseHll data + /// @return The index bit length used in the serialized HLL + static int8_t deserializeIndexBitLength(const char* input); + + /// Returns an estimate of memory usage for DenseHll instance with the + /// specified number of bits per bucket. + /// @param indexBitLength Number of bits per bucket (must be in [4,16]) + /// @return Estimated memory usage in bytes + static int32_t estimateInMemorySize(int8_t indexBitLength); +}; /// HyperLogLog implementation using dense storage layout. /// The number of bits to use as bucket (indexBitLength) is specified by the @@ -26,18 +64,19 @@ class SparseHll; /// /// Memory usage: 2 ^ (indexBitLength - 1) bytes. 2KB for indexBitLength of 12 /// which provides max standard error of 0.023. 
+template class DenseHll { public: - DenseHll(int8_t indexBitLength, HashStringAllocator* allocator); + template + using TStlAllocator = typename TAllocator::template TStlAllocator; + + DenseHll(int8_t indexBitLength, TAllocator* allocator); - DenseHll(const char* serialized, HashStringAllocator* allocator); + DenseHll(const char* serialized, TAllocator* allocator); /// Creates an uninitialized instance that doesn't allcate any significant /// memory. The caller must call initialize before using the HLL. - explicit DenseHll(HashStringAllocator* allocator) - : deltas_{StlAllocator(allocator)}, - overflowBuckets_{StlAllocator(allocator)}, - overflowValues_{StlAllocator(allocator)} {} + explicit DenseHll(TAllocator* allocator); /// Allocates memory that can fit 2 ^ indexBitLength buckets. void initialize(int8_t indexBitLength); @@ -55,28 +94,9 @@ class DenseHll { int64_t cardinality() const; - static int64_t cardinality(const char* serialized); - /// Serializes internal state using Presto DenseV2 format. void serialize(char* output); - /// Returns true if 'input' contains Presto DenseV2 format indicator. - static bool canDeserialize(const char* input); - - /// Returns true if 'input' contains Presto DenseV2 format indicator and the - /// rest of the data matches HLL format: - /// 1 byte for version - /// 1 byte for index bit length, index bit length must be in [4,16] - /// 1 byte for baseline value - /// 2^(n-1) bytes for buckets, values in buckets must be in [0,63] - /// 2 bytes for # overflow buckets - /// 3 * #overflow buckets bytes for overflow buckets/values - /// More information here: - /// https://engineering.fb.com/2018/12/13/data-infrastructure/hyperloglog/ - static bool canDeserialize(const char* input, int size); - - static int8_t deserializeIndexBitLength(const char* input); - /// Returns the size of the serialized state without serialising. 
int32_t serializedSize() const; @@ -86,10 +106,6 @@ class DenseHll { void mergeWith(const char* serialized); - /// Returns an estimate of memory usage for DenseHll instance with the - /// specified number of bits per bucket. - static int32_t estimateInMemorySize(int8_t indexBitLength); - private: int8_t getDelta(int32_t index) const; @@ -147,20 +163,19 @@ class DenseHll { /// Number of zero deltas. int32_t baselineCount_; + TAllocator* allocator_; + /// Per-bucket values represented as deltas from the baseline_. Each entry /// stores 2 values, 4 bits each. The maximum value that can be stored is 15. /// Larger values are stored in a separate overflow list. - std::vector> deltas_; - - /// Number of overflowing values, e.g. values where delta from baseline is - /// greater than 15. + std::vector> deltas_; int16_t overflows_{0}; /// List of buckets with overflowing values. - std::vector> overflowBuckets_; + std::vector> overflowBuckets_; /// Overflowing values stored as deltas from the deltas: value - 15 - /// baseline. 
- std::vector> overflowValues_; + std::vector> overflowValues_; }; } // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/SparseHll.cpp b/velox/common/hyperloglog/SparseHll.cpp index 27d290dd10ef..ed1fc10800ea 100644 --- a/velox/common/hyperloglog/SparseHll.cpp +++ b/velox/common/hyperloglog/SparseHll.cpp @@ -34,9 +34,8 @@ inline uint32_t decodeValue(uint32_t entry) { return entry & ((1 << kValueBitLength) - 1); } -int searchIndex( - uint32_t index, - const std::vector>& entries) { +template +int searchIndex(uint32_t index, const VectorType& entries) { int low = 0; int high = entries.size() - 1; @@ -69,7 +68,65 @@ common::InputByteStream initializeInputStream(const char* serialized) { } } // namespace -bool SparseHll::insertHash(uint64_t hash) { +// Static utility functions implementation +int64_t SparseHlls::cardinality(const char* serialized) { + static const int kTotalBuckets = 1 << kIndexBitLength; + + auto stream = initializeInputStream(serialized); + auto size = stream.read(); + + int zeroBuckets = kTotalBuckets - size; + return std::round(linearCounting(zeroBuckets, kTotalBuckets)); +} + +std::string SparseHlls::serializeEmpty(int8_t indexBitLength) { + static const size_t kSize = 4; + + std::string serialized; + serialized.resize(kSize); + + common::OutputByteStream stream(serialized.data()); + stream.appendOne(kPrestoSparseV2); + stream.appendOne(indexBitLength); + stream.appendOne(static_cast(0)); + return serialized; +} + +bool SparseHlls::canDeserialize(const char* input) { + return *reinterpret_cast(input) == kPrestoSparseV2; +} + +int8_t SparseHlls::deserializeIndexBitLength(const char* input) { + common::InputByteStream stream(input); + stream.read(); // Skip version + return stream.read(); // Return indexBitLength +} + +// Template method implementations +template +SparseHll::SparseHll(TAllocator* allocator) + : allocator_(allocator), entries_{TStlAllocator(allocator)} {} + +template +SparseHll::SparseHll(const char* 
serialized, TAllocator* allocator) + : allocator_(allocator), entries_{TStlAllocator(allocator)} { + common::InputByteStream stream(serialized); + auto version = stream.read(); + VELOX_CHECK_EQ(kPrestoSparseV2, version); + + // Skip indexBitLength from serialized data - we use fixed kIndexBitLength + // internally + stream.read(); + + auto size = stream.read(); + entries_.resize(size); + for (auto i = 0; i < size; i++) { + entries_[i] = stream.read(); + } +} + +template +bool SparseHll::insertHash(uint64_t hash) { auto index = computeIndex(hash, kIndexBitLength); auto value = numberOfLeadingZeros(hash, kIndexBitLength); @@ -88,29 +145,21 @@ bool SparseHll::insertHash(uint64_t hash) { return overLimit(); } -int64_t SparseHll::cardinality() const { +template +int64_t SparseHll::cardinality() const { // Estimate the cardinality using linear counting over the theoretical // 2^kIndexBitLength buckets available due to the fact that we're // recording the raw leading kIndexBitLength of the hash. This produces // much better precision while in the sparse regime. 
- static const int kTotalBuckets = 1 << kIndexBitLength; + const int kTotalBuckets = 1 << kIndexBitLength; int zeroBuckets = kTotalBuckets - entries_.size(); return std::round(linearCounting(zeroBuckets, kTotalBuckets)); } -// static -int64_t SparseHll::cardinality(const char* serialized) { - static const int kTotalBuckets = 1 << kIndexBitLength; - - auto stream = initializeInputStream(serialized); - auto size = stream.read(); - - int zeroBuckets = kTotalBuckets - size; - return std::round(linearCounting(zeroBuckets, kTotalBuckets)); -} - -void SparseHll::serialize(int8_t indexBitLength, char* output) const { +template +void SparseHll::serialize(int8_t indexBitLength, char* output) + const { common::OutputByteStream stream(output); stream.appendOne(kPrestoSparseV2); stream.appendOne(indexBitLength); @@ -120,75 +169,54 @@ void SparseHll::serialize(int8_t indexBitLength, char* output) const { } } -// static -std::string SparseHll::serializeEmpty(int8_t indexBitLength) { - static const size_t kSize = 4; - - std::string serialized; - serialized.resize(kSize); - - common::OutputByteStream stream(serialized.data()); - stream.appendOne(kPrestoSparseV2); - stream.appendOne(indexBitLength); - stream.appendOne(static_cast(0)); - return serialized; -} - -// static -bool SparseHll::canDeserialize(const char* input) { - return *reinterpret_cast(input) == kPrestoSparseV2; -} - -int32_t SparseHll::serializedSize() const { +template +int32_t SparseHll::serializedSize() const { return 1 /* version */ + 1 /* indexBitLength */ + 2 /* number of entries */ + entries_.size() * 4; } -int32_t SparseHll::inMemorySize() const { +template +int32_t SparseHll::inMemorySize() const { return sizeof(uint32_t) * entries_.size(); } -SparseHll::SparseHll(const char* serialized, HashStringAllocator* allocator) - : entries_{StlAllocator(allocator)} { - auto stream = initializeInputStream(serialized); - - auto size = stream.read(); - entries_.resize(size); - for (auto i = 0; i < size; i++) { - 
entries_[i] = stream.read(); - } -} - -void SparseHll::mergeWith(const SparseHll& other) { +template +void SparseHll::mergeWith(const SparseHll& other) { auto size = other.entries_.size(); // This check prevents merge aggregation from being performed on - // empty_approx_set(), an empty HyperLogLog. The merge function typically does - // not take an empty HyperLogLog structure as an argument. + // empty_approx_set(), an empty HyperLogLog. The merge function typically + // does not take an empty HyperLogLog structure as an argument. if (size) { mergeWith(size, other.entries_.data()); } } -void SparseHll::mergeWith(const char* serialized) { +template +void SparseHll::mergeWith(const char* serialized) { auto stream = initializeInputStream(serialized); auto size = stream.read(); // This check prevents merge aggregation from being performed on - // empty_approx_set(), an empty HyperLogLog. The merge function typically does - // not take an empty HyperLogLog structure as an argument. + // empty_approx_set(), an empty HyperLogLog. The merge function typically + // does not take an empty HyperLogLog structure as an argument. 
if (size) { mergeWith( size, reinterpret_cast(serialized + stream.offset())); } } -void SparseHll::mergeWith(size_t otherSize, const uint32_t* otherEntries) { +template +void SparseHll::mergeWith( + size_t otherSize, + const uint32_t* otherEntries) { VELOX_CHECK_GT(otherSize, 0); auto size = entries_.size(); - std::vector merged(size + otherSize); + + auto merged = std::vector>( + size + otherSize, TStlAllocator(allocator_)); int pos = 0; int leftPos = 0; @@ -223,7 +251,8 @@ void SparseHll::mergeWith(size_t otherSize, const uint32_t* otherEntries) { } } -void SparseHll::verify() const { +template +void SparseHll::verify() const { if (entries_.size() <= 1) { return; } @@ -236,11 +265,11 @@ void SparseHll::verify() const { } } -void SparseHll::toDense(DenseHll& denseHll) const { +template +void SparseHll::toDense(DenseHll& denseHll) const { auto indexBitLength = denseHll.indexBitLength(); - for (auto i = 0; i < entries_.size(); i++) { - auto entry = entries_[i]; + for (auto entry : entries_) { auto index = entry >> (32 - indexBitLength); auto shiftedValue = entry << indexBitLength; auto zeros = shiftedValue == 0 ? 32 : __builtin_clz(shiftedValue); @@ -257,4 +286,9 @@ void SparseHll::toDense(DenseHll& denseHll) const { } } +// Explicit template instantiation for HashStringAllocator (default) +template class SparseHll; +// Explicit template instantiation for memory::MemoryPool +template class SparseHll; + } // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/SparseHll.h b/velox/common/hyperloglog/SparseHll.h index e881b5776270..61a3cb8cb29a 100644 --- a/velox/common/hyperloglog/SparseHll.h +++ b/velox/common/hyperloglog/SparseHll.h @@ -18,15 +18,42 @@ #include "velox/common/memory/HashStringAllocator.h" namespace facebook::velox::common::hll { + +class SparseHlls { + public: + /// Returns cardinality estimate from the specified serialized digest. 
+ /// @param serialized Pointer to serialized SparseHll data + /// @return Estimated cardinality of the HyperLogLog + static int64_t cardinality(const char* serialized); + + /// Returns true if 'input' has Presto SparseV2 format. + /// @param input Pointer to serialized data to check + /// @return True if the data is in SparseV2 format, false otherwise + static bool canDeserialize(const char* input); + + /// Creates an empty serialized SparseHll with the specified index bit length. + /// @param indexBitLength Number of bits for indexing (must be in [4,16]) + /// @return Serialized empty SparseHll as a string + static std::string serializeEmpty(int8_t indexBitLength); + + /// Extracts the index bit length from serialized SparseHll data. + /// @param input Pointer to serialized SparseHll data + /// @return The index bit length used in the serialized HLL + static int8_t deserializeIndexBitLength(const char* input); +}; + /// HyperLogLog implementation using sparse storage layout. /// It uses 26-bit buckets and provides high accuracy for low cardinalities. /// Memory usage: 4 bytes for each observed bucket. +template class SparseHll { public: - explicit SparseHll(HashStringAllocator* allocator) - : entries_{StlAllocator(allocator)} {} + template + using TStlAllocator = typename TAllocator::template TStlAllocator; + + explicit SparseHll(TAllocator* allocator); - SparseHll(const char* serialized, HashStringAllocator* allocator); + SparseHll(const char* serialized, TAllocator* allocator); void setSoftMemoryLimit(uint32_t softMemoryLimit) { softNumEntriesLimit_ = softMemoryLimit / 4; @@ -42,17 +69,9 @@ class SparseHll { int64_t cardinality() const; - /// Returns cardinality estimate from the specified serialized digest. - static int64_t cardinality(const char* serialized); - /// Serializes internal state using Presto SparseV2 format. 
void serialize(int8_t indexBitLength, char* output) const; - static std::string serializeEmpty(int8_t indexBitLength); - - /// Returns true if 'input' has Presto SparseV2 format. - static bool canDeserialize(const char* input); - /// Returns the size of the serialized state without serialising. int32_t serializedSize() const; @@ -63,7 +82,7 @@ class SparseHll { void mergeWith(const char* serialized); /// Merges state into provided instance of DenseHll. - void toDense(DenseHll& denseHll) const; + void toDense(DenseHll& denseHll) const; /// Returns current memory usage. int32_t inMemorySize() const; @@ -84,8 +103,8 @@ class SparseHll { /// A list of observed buckets. Each entry is a 32 bit integer encoding 26-bit /// bucket and 6-bit value (number of zeros in the input hash after the bucket /// + 1). - std::vector> entries_; - + TAllocator* allocator_; + std::vector> entries_; /// Number of entries that can be stored before reaching soft memory limit. uint32_t softNumEntriesLimit_{0}; }; diff --git a/velox/common/hyperloglog/benchmarks/DenseHll.cpp b/velox/common/hyperloglog/benchmarks/DenseHll.cpp index 7233280f1d66..0bd112e0a720 100644 --- a/velox/common/hyperloglog/benchmarks/DenseHll.cpp +++ b/velox/common/hyperloglog/benchmarks/DenseHll.cpp @@ -49,7 +49,7 @@ class DenseHllBenchmark { folly::BenchmarkSuspender suspender; HashStringAllocator allocator(pool_); - common::hll::DenseHll hll(hashBits, &allocator); + common::hll::DenseHll<> hll(hashBits, &allocator); suspender.dismiss(); @@ -61,7 +61,7 @@ class DenseHllBenchmark { private: std::string makeSerializedHll(int hashBits, int32_t step) { HashStringAllocator allocator(pool_); - common::hll::DenseHll hll(hashBits, &allocator); + common::hll::DenseHll<> hll(hashBits, &allocator); for (int32_t i = 0; i < 1'000'000; ++i) { auto hash = hashOne(i * step); hll.insertHash(hash); @@ -69,7 +69,7 @@ class DenseHllBenchmark { return serialize(hll); } - static std::string serialize(common::hll::DenseHll& denseHll) { + 
static std::string serialize(common::hll::DenseHll<>& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); diff --git a/velox/common/hyperloglog/tests/DenseHllTest.cpp b/velox/common/hyperloglog/tests/DenseHllTest.cpp index c688a6918c47..3f420ae2c4ae 100644 --- a/velox/common/hyperloglog/tests/DenseHllTest.cpp +++ b/velox/common/hyperloglog/tests/DenseHllTest.cpp @@ -15,9 +15,8 @@ */ #include "velox/common/hyperloglog/DenseHll.h" +#include #include -#include -#include #include #define XXH_INLINE_ALL @@ -34,22 +33,27 @@ uint64_t hashOne(T value) { return XXH64(&value, sizeof(value), 0); } -class DenseHllTest : public ::testing::TestWithParam { +template +class DenseHllTest : public ::testing::Test { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } - DenseHll roundTrip(DenseHll& hll) { - auto size = hll.serializedSize(); - std::string serialized; - serialized.resize(size); - hll.serialize(serialized.data()); + void SetUp() override { + if constexpr (std::is_same_v) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } - return DenseHll(serialized.data(), &allocator_); + DenseHll roundTrip(DenseHll& hll) { + auto serialized = this->serialize(hll); + return DenseHll(serialized.data(), allocator_); } - std::string serialize(DenseHll& denseHll) { + std::string serialize(DenseHll& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -58,23 +62,19 @@ class DenseHllTest : public ::testing::TestWithParam { } template - void testMergeWith( - int8_t indexBitLength, - const std::vector& left, - const std::vector& right) { - testMergeWith(indexBitLength, left, right, false); - testMergeWith(indexBitLength, left, right, true); + void testMergeWith(const std::vector& left, const std::vector& right) { + testMergeWith(left, right, false); + testMergeWith(left, right, true); } template void 
testMergeWith( - int8_t indexBitLength, const std::vector& left, const std::vector& right, bool serialized) { - DenseHll hllLeft{indexBitLength, &allocator_}; - DenseHll hllRight{indexBitLength, &allocator_}; - DenseHll expected{indexBitLength, &allocator_}; + DenseHll hllLeft{11, allocator_}; + DenseHll hllRight{11, allocator_}; + DenseHll expected{11, allocator_}; for (auto value : left) { auto hash = hashOne(value); @@ -89,30 +89,51 @@ class DenseHllTest : public ::testing::TestWithParam { } if (serialized) { - auto serializedRight = serialize(hllRight); + auto serializedRight = this->serialize(hllRight); hllLeft.mergeWith(serializedRight.data()); } else { hllLeft.mergeWith(hllRight); } ASSERT_EQ(hllLeft.cardinality(), expected.cardinality()); - ASSERT_EQ(serialize(hllLeft), serialize(expected)); + ASSERT_EQ(this->serialize(hllLeft), this->serialize(expected)); - auto hllLeftSerialized = serialize(hllLeft); + auto hllLeftSerialized = this->serialize(hllLeft); ASSERT_EQ( - DenseHll::cardinality(hllLeftSerialized.data()), + DenseHlls::cardinality(hllLeftSerialized.data()), expected.cardinality()); } std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; - HashStringAllocator allocator_{pool_.get()}; + HashStringAllocator hsa_{pool_.get()}; + TAllocator* allocator_; }; -TEST_P(DenseHllTest, basic) { - int8_t indexBitLength = GetParam(); +using AllocatorTypes = + ::testing::Types; + +class NameGenerator { + public: + template + static std::string GetName(int) { + if constexpr (std::is_same_v) { + return "hsa"; + } else if constexpr (std::is_same_v) { + return "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + } +}; + +TYPED_TEST_SUITE(DenseHllTest, AllocatorTypes, NameGenerator); + +TYPED_TEST(DenseHllTest, basic) { + int8_t indexBitLength = 11; + DenseHll denseHll{indexBitLength, this->allocator_}; - DenseHll denseHll{indexBitLength, &allocator_}; for (int i = 0; i < 1'000; i++) { auto 
value = i % 17; auto hash = hashOne(value); @@ -131,31 +152,29 @@ TEST_P(DenseHllTest, basic) { ASSERT_EQ(expectedCardinality, denseHll.cardinality()); - DenseHll deserialized = roundTrip(denseHll); + DenseHll deserialized = this->roundTrip(denseHll); ASSERT_EQ(expectedCardinality, deserialized.cardinality()); - auto serialized = serialize(denseHll); - ASSERT_EQ(expectedCardinality, DenseHll::cardinality(serialized.data())); + auto serialized = this->serialize(denseHll); + ASSERT_EQ(expectedCardinality, DenseHlls::cardinality(serialized.data())); } -TEST_P(DenseHllTest, highCardinality) { - int8_t indexBitLength = GetParam(); +TYPED_TEST(DenseHllTest, highCardinality) { + int8_t indexBitLength = 11; + DenseHll denseHll{indexBitLength, this->allocator_}; - DenseHll denseHll{indexBitLength, &allocator_}; for (int i = 0; i < 10'000'000; i++) { auto hash = hashOne(i); denseHll.insertHash(hash); } - if (indexBitLength >= 11) { - ASSERT_NEAR(10'000'000, denseHll.cardinality(), 150'000); - } + ASSERT_NEAR(10'000'000, denseHll.cardinality(), 150'000); - DenseHll deserialized = roundTrip(denseHll); + auto deserialized = this->roundTrip(denseHll); ASSERT_EQ(denseHll.cardinality(), deserialized.cardinality()); - auto serialized = serialize(denseHll); - ASSERT_EQ(denseHll.cardinality(), DenseHll::cardinality(serialized.data())); + auto serialized = this->serialize(denseHll); + ASSERT_EQ(denseHll.cardinality(), DenseHlls::cardinality(serialized.data())); } namespace { @@ -170,7 +189,181 @@ std::vector sequence(T start, T end) { } } // namespace -TEST_P(DenseHllTest, canDeserialize) { +TYPED_TEST(DenseHllTest, mergeWith) { + // small, non-overlapping + this->testMergeWith(sequence(0, 100), sequence(100, 200)); + this->testMergeWith(sequence(100, 200), sequence(0, 100)); + + // small, overlapping + this->testMergeWith(sequence(0, 100), sequence(50, 150)); + this->testMergeWith(sequence(50, 150), sequence(0, 100)); + + // small, same + this->testMergeWith(sequence(0, 100), 
sequence(0, 100)); + + // large, non-overlapping + this->testMergeWith(sequence(0, 20'000), sequence(20'000, 40'000)); + this->testMergeWith(sequence(20'000, 40'000), sequence(0, 20'000)); + + // large, overlapping + this->testMergeWith(sequence(0, 2'000'000), sequence(1'000'000, 3'000'000)); + this->testMergeWith(sequence(1'000'000, 3'000'000), sequence(0, 2'000'000)); + + // large, same + this->testMergeWith(sequence(0, 2'000'000), sequence(0, 2'000'000)); +} + +// Separate test class for testing various index bit lengths +template +struct AllocatorWithIndexBits { + using AllocatorType = TAllocator; + static constexpr int8_t indexBitLength() { + return IndexBitLength; + } +}; + +template +class DenseHllMergeTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } + + std::string serialize(DenseHll& denseHll) { + auto size = denseHll.serializedSize(); + std::string serialized; + serialized.resize(size); + denseHll.serialize(serialized.data()); + return serialized; + } + + template + void testMergeWith( + int8_t indexBitLength, + const std::vector& left, + const std::vector& right) { + testMergeWith(indexBitLength, left, right, false); + testMergeWith(indexBitLength, left, right, true); + } + + template + void testMergeWith( + int8_t indexBitLength, + const std::vector& left, + const std::vector& right, + bool serialized) { + DenseHll hllLeft{indexBitLength, allocator_}; + DenseHll hllRight{indexBitLength, allocator_}; + DenseHll expected{indexBitLength, allocator_}; + + for (auto value : left) { + auto hash = hashOne(value); + hllLeft.insertHash(hash); + expected.insertHash(hash); + } + + for (auto value : right) { + auto hash = hashOne(value); + hllRight.insertHash(hash); + 
expected.insertHash(hash); + } + + if (serialized) { + auto serializedRight = this->serialize(hllRight); + hllLeft.mergeWith(serializedRight.data()); + } else { + hllLeft.mergeWith(hllRight); + } + + ASSERT_EQ(hllLeft.cardinality(), expected.cardinality()); + ASSERT_EQ(this->serialize(hllLeft), this->serialize(expected)); + + auto hllLeftSerialized = this->serialize(hllLeft); + ASSERT_EQ( + DenseHlls::cardinality(hllLeftSerialized.data()), + expected.cardinality()); + } + + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool()}; + HashStringAllocator hsa_{pool_.get()}; + typename TParam::AllocatorType* allocator_; +}; + +using DenseHllMergeTestParams = ::testing::Types< + // HashStringAllocator with all index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + // MemoryPool with all index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits>; + +class ComprehensiveNameGenerator { + public: + template + static std::string GetName(int) { + std::string allocatorName; + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocatorName = "hsa"; + } else if constexpr (std::is_same_v< + typename TParam::AllocatorType, + memory::MemoryPool>) { + allocatorName = "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + return fmt::format("{}_{}", allocatorName, 
TParam::indexBitLength()); + } +}; + +TYPED_TEST_SUITE( + DenseHllMergeTest, + DenseHllMergeTestParams, + ComprehensiveNameGenerator); + +class DenseHllCanDeserializeTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } +}; + +TEST_F(DenseHllCanDeserializeTest, canDeserialize) { // These are not valid HLL but all pass canDeserialize version only check. std::vector invalidStrings{ "AxIRESUhEzNBFCQWYxEjIzI1ISURNidCMlViIjOSNyATBhYSIiJDUyMBIlcSMDUiEUEiESM1ITckQkQTMSMhMyQx", @@ -181,9 +374,9 @@ TEST_P(DenseHllTest, canDeserialize) { for (folly::StringPiece& invalidString : invalidStrings) { auto invalidHll = Base64::decode(invalidString); - EXPECT_TRUE(DenseHll::canDeserialize(invalidHll.c_str())); + EXPECT_TRUE(DenseHlls::canDeserialize(invalidHll.c_str())); EXPECT_FALSE( - DenseHll::canDeserialize(invalidHll.c_str(), invalidHll.length())); + DenseHlls::canDeserialize(invalidHll.c_str(), invalidHll.length())); } std::vector validStrings{ @@ -192,40 +385,38 @@ TEST_P(DenseHllTest, canDeserialize) { for (folly::StringPiece& validString : validStrings) { auto validHll = Base64::decode(validString); - EXPECT_TRUE(DenseHll::canDeserialize(validHll.c_str())); - EXPECT_TRUE(DenseHll::canDeserialize(validHll.c_str(), validHll.length())); + EXPECT_TRUE(DenseHlls::canDeserialize(validHll.c_str())); + EXPECT_TRUE(DenseHlls::canDeserialize(validHll.c_str(), validHll.length())); } } -TEST_P(DenseHllTest, mergeWith) { - int8_t indexBitLength = GetParam(); +TYPED_TEST(DenseHllMergeTest, mergeWith) { + int8_t indexBitLength = TypeParam::indexBitLength(); // small, non-overlapping - testMergeWith(indexBitLength, sequence(0, 100), sequence(100, 200)); - testMergeWith(indexBitLength, sequence(100, 200), sequence(0, 100)); + this->testMergeWith(indexBitLength, sequence(0, 100), sequence(100, 200)); + this->testMergeWith(indexBitLength, sequence(100, 200), sequence(0, 100)); // small, 
overlapping - testMergeWith(indexBitLength, sequence(0, 100), sequence(50, 150)); - testMergeWith(indexBitLength, sequence(50, 150), sequence(0, 100)); + this->testMergeWith(indexBitLength, sequence(0, 100), sequence(50, 150)); + this->testMergeWith(indexBitLength, sequence(50, 150), sequence(0, 100)); // small, same - testMergeWith(indexBitLength, sequence(0, 100), sequence(0, 100)); + this->testMergeWith(indexBitLength, sequence(0, 100), sequence(0, 100)); // large, non-overlapping - testMergeWith(indexBitLength, sequence(0, 20'000), sequence(20'000, 40'000)); - testMergeWith(indexBitLength, sequence(20'000, 40'000), sequence(0, 20'000)); + this->testMergeWith( + indexBitLength, sequence(0, 20'000), sequence(20'000, 40'000)); + this->testMergeWith( + indexBitLength, sequence(20'000, 40'000), sequence(0, 20'000)); // large, overlapping - testMergeWith( + this->testMergeWith( indexBitLength, sequence(0, 2'000'000), sequence(1'000'000, 3'000'000)); - testMergeWith( + this->testMergeWith( indexBitLength, sequence(1'000'000, 3'000'000), sequence(0, 2'000'000)); // large, same - testMergeWith(indexBitLength, sequence(0, 2'000'000), sequence(0, 2'000'000)); + this->testMergeWith( + indexBitLength, sequence(0, 2'000'000), sequence(0, 2'000'000)); } - -INSTANTIATE_TEST_SUITE_P( - DenseHllTest, - DenseHllTest, - ::testing::Values(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)); diff --git a/velox/common/hyperloglog/tests/SparseHllTest.cpp b/velox/common/hyperloglog/tests/SparseHllTest.cpp index 299e2c8aebd9..9d2ae9238d0d 100644 --- a/velox/common/hyperloglog/tests/SparseHllTest.cpp +++ b/velox/common/hyperloglog/tests/SparseHllTest.cpp @@ -18,6 +18,7 @@ #define XXH_INLINE_ALL #include +#include #include using namespace facebook::velox; @@ -28,12 +29,21 @@ uint64_t hashOne(T value) { return XXH64(&value, sizeof(value), 0); } +template class SparseHllTest : public ::testing::Test { protected: static void SetUpTestCase() { 
memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + void SetUp() override { + if constexpr (std::is_same_v) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } + template void testMergeWith(const std::vector& left, const std::vector& right) { testMergeWith(left, right, false); @@ -45,9 +55,9 @@ class SparseHllTest : public ::testing::Test { const std::vector& left, const std::vector& right, bool serialized) { - SparseHll hllLeft{&allocator_}; - SparseHll hllRight{&allocator_}; - SparseHll expected{&allocator_}; + SparseHll hllLeft{allocator_}; + SparseHll hllRight{allocator_}; + SparseHll expected{allocator_}; for (auto value : left) { auto hash = hashOne(value); @@ -77,16 +87,20 @@ class SparseHllTest : public ::testing::Test { auto hllLeftSerialized = serialize(11, hllLeft); ASSERT_EQ( - SparseHll::cardinality(hllLeftSerialized.data()), + SparseHlls::cardinality(hllLeftSerialized.data()), expected.cardinality()); } - SparseHll roundTrip(SparseHll& hll) { - auto serialized = serialize(11, hll); - return SparseHll(serialized.data(), &allocator_); + SparseHll roundTrip( + SparseHll& hll, + int8_t indexBitLength = 11) { + auto serialized = serialize(indexBitLength, hll); + return SparseHll(serialized.data(), allocator_); } - std::string serialize(int8_t indexBitLength, const SparseHll& sparseHll) { + std::string serialize( + int8_t indexBitLength, + const SparseHll& sparseHll) { auto size = sparseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -94,7 +108,7 @@ class SparseHllTest : public ::testing::Test { return serialized; } - std::string serialize(DenseHll& denseHll) { + std::string serialize(DenseHll& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -104,11 +118,32 @@ class SparseHllTest : public ::testing::Test { std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; - HashStringAllocator allocator_{pool_.get()}; + 
HashStringAllocator hsa_{pool_.get()}; + TAllocator* allocator_; +}; + +using AllocatorTypes = + ::testing::Types; + +class NameGenerator { + public: + template + static std::string GetName(int) { + if constexpr (std::is_same_v) { + return "hsa"; + } else if constexpr (std::is_same_v) { + return "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + } }; -TEST_F(SparseHllTest, basic) { - SparseHll sparseHll{&allocator_}; +TYPED_TEST_SUITE(SparseHllTest, AllocatorTypes, NameGenerator); + +TYPED_TEST(SparseHllTest, basic) { + SparseHll sparseHll{this->allocator_}; for (int i = 0; i < 1'000; i++) { auto value = i % 17; auto hash = hashOne(value); @@ -118,16 +153,16 @@ TEST_F(SparseHllTest, basic) { sparseHll.verify(); ASSERT_EQ(17, sparseHll.cardinality()); - auto deserialized = roundTrip(sparseHll); + auto deserialized = this->roundTrip(sparseHll); deserialized.verify(); ASSERT_EQ(17, deserialized.cardinality()); - auto serialized = serialize(11, sparseHll); - ASSERT_EQ(17, SparseHll::cardinality(serialized.data())); + auto serialized = this->serialize(11, sparseHll); + ASSERT_EQ(17, SparseHlls::cardinality(serialized.data())); } -TEST_F(SparseHllTest, highCardinality) { - SparseHll sparseHll{&allocator_}; +TYPED_TEST(SparseHllTest, highCardinality) { + SparseHll sparseHll{this->allocator_}; for (int i = 0; i < 1'000; i++) { auto hash = hashOne(i); sparseHll.insertHash(hash); @@ -136,12 +171,12 @@ TEST_F(SparseHllTest, highCardinality) { sparseHll.verify(); ASSERT_EQ(1'000, sparseHll.cardinality()); - auto deserialized = roundTrip(sparseHll); + auto deserialized = this->roundTrip(sparseHll); deserialized.verify(); ASSERT_EQ(1'000, deserialized.cardinality()); - auto serialized = serialize(11, sparseHll); - ASSERT_EQ(1'000, SparseHll::cardinality(serialized.data())); + auto serialized = this->serialize(11, sparseHll); + ASSERT_EQ(1'000, SparseHlls::cardinality(serialized.data())); } namespace { @@ 
-156,30 +191,80 @@ std::vector sequence(T start, T end) { } } // namespace -TEST_F(SparseHllTest, mergeWith) { +TYPED_TEST(SparseHllTest, mergeWith) { // with overlap - testMergeWith(sequence(0, 100), sequence(50, 150)); - testMergeWith(sequence(50, 150), sequence(0, 100)); + this->testMergeWith(sequence(0, 100), sequence(50, 150)); + this->testMergeWith(sequence(50, 150), sequence(0, 100)); // no overlap - testMergeWith(sequence(0, 100), sequence(200, 300)); - testMergeWith(sequence(200, 300), sequence(0, 100)); + this->testMergeWith(sequence(0, 100), sequence(200, 300)); + this->testMergeWith(sequence(200, 300), sequence(0, 100)); // idempotent - testMergeWith(sequence(0, 100), sequence(0, 100)); + this->testMergeWith(sequence(0, 100), sequence(0, 100)); // empty sequence - testMergeWith(sequence(0, 100), {}); - testMergeWith({}, sequence(100, 300)); + this->testMergeWith(sequence(0, 100), {}); + this->testMergeWith({}, sequence(100, 300)); +} + +TYPED_TEST(SparseHllTest, toDense) { + int8_t indexBitLength = 11; + + SparseHll sparseHll{this->allocator_}; + DenseHll expectedHll{indexBitLength, this->allocator_}; + for (int i = 0; i < 1'000; i++) { + auto hash = hashOne(i); + sparseHll.insertHash(hash); + expectedHll.insertHash(hash); + } + + DenseHll denseHll{indexBitLength, this->allocator_}; + sparseHll.toDense(denseHll); + ASSERT_EQ(denseHll.cardinality(), expectedHll.cardinality()); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); +} + +TYPED_TEST(SparseHllTest, testNumberOfZeros) { + int8_t indexBitLength = 11; + for (int i = 0; i < 64 - indexBitLength; ++i) { + auto hash = 1ull << i; + SparseHll sparseHll(this->allocator_); + sparseHll.insertHash(hash); + DenseHll expectedHll(indexBitLength, this->allocator_); + expectedHll.insertHash(hash); + DenseHll denseHll(indexBitLength, this->allocator_); + sparseHll.toDense(denseHll); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); + } } -class SparseHllToDenseTest : public 
::testing::TestWithParam { +template +struct AllocatorWithIndexBits { + using AllocatorType = TAllocator; + static constexpr int8_t indexBitLength() { + return IndexBitLength; + } +}; + +template +class SparseHllToDenseTest : public ::testing::Test { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } - std::string serialize(DenseHll& denseHll) { + void SetUp() override { + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } + + std::string serialize(DenseHll& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -189,41 +274,93 @@ class SparseHllToDenseTest : public ::testing::TestWithParam { std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; - HashStringAllocator allocator_{pool_.get()}; + HashStringAllocator hsa_{pool_.get()}; + typename TParam::AllocatorType* allocator_; }; -TEST_P(SparseHllToDenseTest, toDense) { - int8_t indexBitLength = GetParam(); +using SparseHllToDenseTestParams = ::testing::Types< + // HashStringAllocator with various index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + // MemoryPool with various index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits>; + +class ToDenseNameGenerator { + public: + template + 
static std::string GetName(int) { + std::string allocatorName; + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocatorName = "hsa"; + } else if constexpr (std::is_same_v< + typename TParam::AllocatorType, + memory::MemoryPool>) { + allocatorName = "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + return fmt::format("{}_{}", allocatorName, TParam::indexBitLength()); + } +}; - SparseHll sparseHll{&allocator_}; - DenseHll expectedHll{indexBitLength, &allocator_}; +TYPED_TEST_SUITE( + SparseHllToDenseTest, + SparseHllToDenseTestParams, + ToDenseNameGenerator); + +TYPED_TEST(SparseHllToDenseTest, toDense) { + int8_t indexBitLength = TypeParam::indexBitLength(); + + SparseHll sparseHll{this->allocator_}; + DenseHll expectedHll{indexBitLength, this->allocator_}; for (int i = 0; i < 1'000; i++) { auto hash = hashOne(i); sparseHll.insertHash(hash); expectedHll.insertHash(hash); } - DenseHll denseHll{indexBitLength, &allocator_}; + DenseHll denseHll{indexBitLength, this->allocator_}; sparseHll.toDense(denseHll); ASSERT_EQ(denseHll.cardinality(), expectedHll.cardinality()); - ASSERT_EQ(serialize(denseHll), serialize(expectedHll)); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); } -TEST_P(SparseHllToDenseTest, testNumberOfZeros) { - auto indexBitLength = GetParam(); +TYPED_TEST(SparseHllToDenseTest, testNumberOfZeros) { + auto indexBitLength = TypeParam::indexBitLength(); for (int i = 0; i < 64 - indexBitLength; ++i) { auto hash = 1ull << i; - SparseHll sparseHll(&allocator_); + SparseHll sparseHll(this->allocator_); sparseHll.insertHash(hash); - DenseHll expectedHll(indexBitLength, &allocator_); + DenseHll expectedHll(indexBitLength, this->allocator_); expectedHll.insertHash(hash); - DenseHll denseHll(indexBitLength, &allocator_); + DenseHll denseHll(indexBitLength, this->allocator_); sparseHll.toDense(denseHll); - 
ASSERT_EQ(serialize(denseHll), serialize(expectedHll)); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); } } - -INSTANTIATE_TEST_SUITE_P( - SparseHllToDenseTest, - SparseHllToDenseTest, - ::testing::Values(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)); diff --git a/velox/common/memory/ByteStream.cpp b/velox/common/memory/ByteStream.cpp index 4fe93b569216..c910f32c9c90 100644 --- a/velox/common/memory/ByteStream.cpp +++ b/velox/common/memory/ByteStream.cpp @@ -16,19 +16,21 @@ #include "velox/common/memory/ByteStream.h" +#include + namespace facebook::velox { +static ByteRange convByteRange(folly::ByteRange br) { + return {const_cast(br.data()), folly::to_signed(br.size()), 0}; +} + std::vector byteRangesFromIOBuf(folly::IOBuf* iobuf) { if (iobuf == nullptr) { return {}; } std::vector byteRanges; - auto* current = iobuf; - do { - byteRanges.push_back( - {current->writableData(), static_cast(current->length()), 0}); - current = current->next(); - } while (current != iobuf); + auto dst = std::back_inserter(byteRanges); + std::transform(iobuf->begin(), iobuf->end(), dst, convByteRange); return byteRanges; } diff --git a/velox/common/memory/HashStringAllocator.h b/velox/common/memory/HashStringAllocator.h index 253bc4f27a1a..78641f429379 100644 --- a/velox/common/memory/HashStringAllocator.h +++ b/velox/common/memory/HashStringAllocator.h @@ -19,7 +19,6 @@ #include "velox/common/memory/AllocationPool.h" #include "velox/common/memory/ByteStream.h" #include "velox/common/memory/CompactDoubleList.h" -#include "velox/common/memory/Memory.h" #include "velox/common/memory/StreamArena.h" #include "velox/type/StringView.h" @@ -27,6 +26,9 @@ namespace facebook::velox { +template +struct StlAllocator; + /// Implements an arena backed by memory::Allocation. This is for backing /// ByteOutputStream or for allocating single blocks. Blocks can be individually /// freed. Adjacent frees are coalesced and free blocks are kept in a free list. 
@@ -41,6 +43,9 @@ namespace facebook::velox { /// backing a HashStringAllocator is set to kArenaEnd. class HashStringAllocator : public StreamArena { public: + template + using TStlAllocator = StlAllocator; + /// The minimum allocation must have space after the header for the free list /// pointers and the trailing length. static constexpr int32_t kMinAlloc = @@ -688,8 +693,12 @@ struct StlAllocator { VELOX_CHECK_NOT_NULL(allocator); } + // We can use "explicit" here based on the C++ standard. But + // libstdc++ 12 or older doesn't work for std::vector and + // "explicit". We can avoid it by not using "explicit" here. + // See also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115854 template - explicit StlAllocator(const StlAllocator& allocator) + StlAllocator(const StlAllocator& allocator) : allocator_{allocator.allocator()} { VELOX_CHECK_NOT_NULL(allocator_); } diff --git a/velox/common/memory/MemoryArbitrator.cpp b/velox/common/memory/MemoryArbitrator.cpp index eb817300608c..dc8f1d57a8f4 100644 --- a/velox/common/memory/MemoryArbitrator.cpp +++ b/velox/common/memory/MemoryArbitrator.cpp @@ -99,7 +99,11 @@ class NoopArbitrator : public MemoryArbitrator { } void removePool(MemoryPool* pool) override { - VELOX_CHECK_EQ(pool->reservedBytes(), 0); + VELOX_CHECK_EQ( + pool->reservedBytes(), + 0, + "Memory pool has unexpected reserved bytes on removal: {}", + pool->name()); } // Noop arbitrator has no memory capacity limit so no operation needed for diff --git a/velox/common/memory/MemoryPool.cpp b/velox/common/memory/MemoryPool.cpp index 2a05c4774c5a..2e132f54e58c 100644 --- a/velox/common/memory/MemoryPool.cpp +++ b/velox/common/memory/MemoryPool.cpp @@ -18,6 +18,7 @@ #include +#include "velox/common/Casts.h" #include "velox/common/base/Counters.h" #include "velox/common/base/StatsReporter.h" #include "velox/common/base/SuccinctPrinter.h" @@ -38,14 +39,14 @@ using facebook::velox::common::testutil::TestValue; namespace facebook::velox::memory { namespace { // Check 
if memory operation is allowed and increment the named stats. -#define CHECK_AND_INC_MEM_OP_STATS(stats) \ +#define CHECK_AND_INC_MEM_OP_STATS(pool, stats) \ do { \ - if (FOLLY_UNLIKELY(kind_ != Kind::kLeaf)) { \ + if (FOLLY_UNLIKELY(pool->kind_ != Kind::kLeaf)) { \ VELOX_FAIL( \ "Memory operation is only allowed on leaf memory pool: {}", \ - toString()); \ + pool->toString()); \ } \ - ++num##stats##_; \ + ++pool->num##stats##_; \ } while (0) // Check if memory operation is allowed and increment the named stats. @@ -153,9 +154,9 @@ std::string capacityToString(int64_t capacity) { return capacity == kMaxMemory ? "UNLIMITED" : succinctBytes(capacity); } -#define DEBUG_RECORD_ALLOC(...) \ - if (FOLLY_UNLIKELY(debugEnabled())) { \ - recordAllocDbg(__VA_ARGS__); \ +#define DEBUG_RECORD_ALLOC(pool, ...) \ + if (FOLLY_UNLIKELY(pool->debugEnabled())) { \ + pool->recordAllocDbg(__VA_ARGS__); \ } #define DEBUG_RECORD_FREE(...) \ if (FOLLY_UNLIKELY(debugEnabled())) { \ @@ -521,7 +522,7 @@ void* MemoryPoolImpl::allocate( } } - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); const auto alignedSize = sizeAlign(size); reserve(alignedSize); void* buffer = allocator_->allocateBytes(alignedSize, alignment_); @@ -534,12 +535,12 @@ void* MemoryPoolImpl::allocate( toString(), allocator_->getAndClearFailureMessage())); } - DEBUG_RECORD_ALLOC(buffer, size); + DEBUG_RECORD_ALLOC(this, buffer, size); return buffer; } void* MemoryPoolImpl::allocateZeroFilled(int64_t numEntries, int64_t sizeEach) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); const auto size = sizeEach * numEntries; const auto alignedSize = sizeAlign(size); reserve(alignedSize); @@ -554,12 +555,12 @@ void* MemoryPoolImpl::allocateZeroFilled(int64_t numEntries, int64_t sizeEach) { toString(), allocator_->getAndClearFailureMessage())); } - DEBUG_RECORD_ALLOC(buffer, size); + DEBUG_RECORD_ALLOC(this, buffer, size); return buffer; } void* 
MemoryPoolImpl::reallocate(void* p, int64_t size, int64_t newSize) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); const auto alignedNewSize = sizeAlign(newSize); reserve(alignedNewSize); @@ -574,7 +575,7 @@ void* MemoryPoolImpl::reallocate(void* p, int64_t size, int64_t newSize) { toString(), allocator_->getAndClearFailureMessage())); } - DEBUG_RECORD_ALLOC(newP, newSize); + DEBUG_RECORD_ALLOC(this, newP, newSize); if (p != nullptr) { ::memcpy(newP, p, std::min(size, newSize)); free(p, size); @@ -583,18 +584,40 @@ void* MemoryPoolImpl::reallocate(void* p, int64_t size, int64_t newSize) { } void MemoryPoolImpl::free(void* p, int64_t size) { - CHECK_AND_INC_MEM_OP_STATS(Frees); + CHECK_AND_INC_MEM_OP_STATS(this, Frees); const auto alignedSize = sizeAlign(size); DEBUG_RECORD_FREE(p, size); allocator_->freeBytes(p, alignedSize); release(alignedSize); } +bool MemoryPoolImpl::transferTo(MemoryPool* dest, void* buffer, uint64_t size) { + if (!isLeaf() || !dest->isLeaf()) { + return false; + } + VELOX_CHECK_NOT_NULL(dest); + auto* destImpl = checked_pointer_cast(dest); + if (allocator_ != destImpl->allocator_) { + return false; + } + + CHECK_AND_INC_MEM_OP_STATS(destImpl, Allocs); + const auto alignedSize = sizeAlign(size); + destImpl->reserve(alignedSize); + DEBUG_RECORD_ALLOC(destImpl, buffer, size); + + CHECK_AND_INC_MEM_OP_STATS(this, Frees); + DEBUG_RECORD_FREE(buffer, size); + release(alignedSize); + + return true; +} + void MemoryPoolImpl::allocateNonContiguous( MachinePageCount numPages, Allocation& out, MachinePageCount minSizeClass) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); if (!out.empty()) { INC_MEM_OP_STATS(Frees); } @@ -622,14 +645,14 @@ void MemoryPoolImpl::allocateNonContiguous( toString(), allocator_->getAndClearFailureMessage())); } - DEBUG_RECORD_ALLOC(out); + DEBUG_RECORD_ALLOC(this, out); VELOX_CHECK(!out.empty()); VELOX_CHECK_NULL(out.pool()); out.setPool(this); } void 
MemoryPoolImpl::freeNonContiguous(Allocation& allocation) { - CHECK_AND_INC_MEM_OP_STATS(Frees); + CHECK_AND_INC_MEM_OP_STATS(this, Frees); DEBUG_RECORD_FREE(allocation); const int64_t freedBytes = allocator_->freeNonContiguous(allocation); VELOX_CHECK(allocation.empty()); @@ -648,7 +671,7 @@ void MemoryPoolImpl::allocateContiguous( MachinePageCount numPages, ContiguousAllocation& out, MachinePageCount maxPages) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); if (!out.empty()) { INC_MEM_OP_STATS(Frees); } @@ -674,14 +697,14 @@ void MemoryPoolImpl::allocateContiguous( toString(), allocator_->getAndClearFailureMessage())); } - DEBUG_RECORD_ALLOC(out); + DEBUG_RECORD_ALLOC(this, out); VELOX_CHECK(!out.empty()); VELOX_CHECK_NULL(out.pool()); out.setPool(this); } void MemoryPoolImpl::freeContiguous(ContiguousAllocation& allocation) { - CHECK_AND_INC_MEM_OP_STATS(Frees); + CHECK_AND_INC_MEM_OP_STATS(this, Frees); const int64_t bytesToFree = allocation.size(); DEBUG_RECORD_FREE(allocation); allocator_->freeContiguous(allocation); @@ -775,7 +798,7 @@ std::shared_ptr MemoryPoolImpl::genChild( } bool MemoryPoolImpl::maybeReserve(uint64_t increment) { - CHECK_AND_INC_MEM_OP_STATS(Reserves); + CHECK_AND_INC_MEM_OP_STATS(this, Reserves); TestValue::adjust( "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", this); // TODO: make this a configurable memory pool option. 
@@ -923,7 +946,7 @@ void MemoryPoolImpl::incrementReservationLocked(uint64_t bytes) { } void MemoryPoolImpl::release() { - CHECK_AND_INC_MEM_OP_STATS(Releases); + CHECK_AND_INC_MEM_OP_STATS(this, Releases); release(0, true); } @@ -980,6 +1003,21 @@ void MemoryPoolImpl::decrementReservation(uint64_t size) noexcept { sanityCheckLocked(); } +std::string MemoryPoolImpl::toString(bool detail) const { + std::string result; + { + std::lock_guard l(mutex_); + result = toStringLocked(); + } + if (detail) { + result += "\n" + treeMemoryUsage(); + } + if (FOLLY_UNLIKELY(debugEnabled())) { + result += "\n" + dumpRecordsDbg(); + } + return result; +} + std::string MemoryPoolImpl::treeMemoryUsage(bool skipEmptyPool) const { if (parent_ != nullptr) { return parent_->treeMemoryUsage(skipEmptyPool); @@ -1225,7 +1263,7 @@ void MemoryPoolImpl::recordAllocDbg(const void* addr, uint64_t size) { succinctBytes(size), succinctBytes(usedBytes), it->second.callStack.toString(), - dumpRecordsDbg()); + dumpRecordsDbgLocked()); } } @@ -1314,7 +1352,7 @@ void MemoryPoolImpl::leakCheckDbg() { dumpRecordsDbg())); } -std::string MemoryPoolImpl::dumpRecordsDbg() { +std::string MemoryPoolImpl::dumpRecordsDbgLocked() const { VELOX_CHECK(debugEnabled()); std::stringstream oss; oss << fmt::format("Found {} allocations:\n", debugAllocRecords_.size()); diff --git a/velox/common/memory/MemoryPool.h b/velox/common/memory/MemoryPool.h index 601f3ed272b6..3bc634de6735 100644 --- a/velox/common/memory/MemoryPool.h +++ b/velox/common/memory/MemoryPool.h @@ -42,6 +42,9 @@ class MemoryManager; constexpr int64_t kMaxMemory = std::numeric_limits::max(); +template +class StlAllocator; + /// This class provides the memory allocation interfaces for a query execution. /// Each query execution entity creates a dedicated memory pool object. 
The /// memory pool objects from a query are organized as a tree with four levels @@ -91,6 +94,9 @@ constexpr int64_t kMaxMemory = std::numeric_limits::max(); /// also provides memory usage accounting. class MemoryPool : public std::enable_shared_from_this { public: + template + using TStlAllocator = StlAllocator; + /// Defines the kinds of a memory pool. enum class Kind { /// The leaf memory pool is used for memory allocation. User can allocate @@ -246,6 +252,13 @@ class MemoryPool : public std::enable_shared_from_this { /// Frees an allocated buffer. virtual void free(void* p, int64_t size) = 0; + /// Transfer the ownership of memory at 'buffer' for 'size' bytes to the + /// memory pool 'dest'. Returns true if the transfer succeeds. + virtual bool + transferTo(MemoryPool* /*dest*/, void* /*buffer*/, uint64_t /*size*/) { + return false; + } + /// Allocates one or more runs that add up to at least 'numPages', with the /// smallest run being at least 'minSizeClass' pages. 'minSizeClass' must be /// <= the size of the largest size class. The new memory is returned in 'out' @@ -610,6 +623,8 @@ class MemoryPoolImpl : public MemoryPool { void free(void* p, int64_t size) override; + bool transferTo(MemoryPool* dest, void* buffer, uint64_t size) override; + void allocateNonContiguous( MachinePageCount numPages, Allocation& out, @@ -673,17 +688,7 @@ class MemoryPoolImpl : public MemoryPool { void setDestructionCallback(const DestructionCallback& callback); - std::string toString(bool detail = false) const override { - std::string result; - { - std::lock_guard l(mutex_); - result = toStringLocked(); - } - if (detail) { - result += "\n" + treeMemoryUsage(); - } - return result; - } + std::string toString(bool detail = false) const override; /// Detailed debug pool state printout by traversing the pool structure from /// the root memory pool. 
@@ -1005,7 +1010,12 @@ class MemoryPoolImpl : public MemoryPool { // Dump the recorded call sites of the memory allocations in // 'debugAllocRecords_' to the string. - std::string dumpRecordsDbg(); + std::string dumpRecordsDbgLocked() const; + + std::string dumpRecordsDbg() const { + std::lock_guard l(debugAllocMutex_); + return dumpRecordsDbgLocked(); + } void handleAllocationFailure(const std::string& failureMessage); @@ -1070,7 +1080,7 @@ class MemoryPoolImpl : public MemoryPool { std::atomic_uint64_t numCapacityGrowths_{0}; // Mutex for 'debugAllocRecords_'. - std::mutex debugAllocMutex_; + mutable std::mutex debugAllocMutex_; // Map from address to 'AllocationRecord'. std::unordered_map debugAllocRecords_; @@ -1088,6 +1098,8 @@ class StlAllocator { /* implicit */ StlAllocator(MemoryPool& pool) : pool{pool} {} + explicit StlAllocator(MemoryPool* pool) : pool{*pool} {} + template /* implicit */ StlAllocator(const StlAllocator& a) : pool{a.pool} {} diff --git a/velox/common/memory/SharedArbitrator.cpp b/velox/common/memory/SharedArbitrator.cpp index e726d76222e3..a5a40667b6d3 100644 --- a/velox/common/memory/SharedArbitrator.cpp +++ b/velox/common/memory/SharedArbitrator.cpp @@ -539,14 +539,14 @@ void SharedArbitrator::addPool(const std::shared_ptr& pool) { } void SharedArbitrator::removePool(MemoryPool* pool) { - VELOX_CHECK_EQ(pool->reservedBytes(), 0); + VELOX_CHECK_EQ(pool->reservedBytes(), 0, "{}", pool->name()); const uint64_t freedBytes = shrinkPool(pool, 0); - VELOX_CHECK_EQ(pool->capacity(), 0); + VELOX_CHECK_EQ(pool->capacity(), 0, "{}", pool->name()); freeCapacity(freedBytes); std::unique_lock guard{participantLock_}; const auto ret = participants_.erase(pool->name()); - VELOX_CHECK_EQ(ret, 1); + VELOX_CHECK_EQ(ret, 1, "{}", pool->name()); } std::vector SharedArbitrator::getCandidates( diff --git a/velox/common/memory/tests/MemoryPoolTest.cpp b/velox/common/memory/tests/MemoryPoolTest.cpp index 9b9ea02985ae..6f8c74d56416 100644 --- 
a/velox/common/memory/tests/MemoryPoolTest.cpp +++ b/velox/common/memory/tests/MemoryPoolTest.cpp @@ -2726,12 +2726,23 @@ TEST(MemoryPoolTest, debugMode) { ->addLeafChild("child"); const auto& allocRecords = std::dynamic_pointer_cast(pool) ->testingDebugAllocRecords(); + std::vector smallAllocs; + smallAllocs.reserve(kNumIterations); for (int32_t i = 0; i < kNumIterations; i++) { smallAllocs.push_back(pool->allocate(kAllocSizes[0])); } EXPECT_EQ(allocRecords.size(), kNumIterations); checkAllocs(allocRecords, kAllocSizes[0]); + + // Check toString() works with debug mode enabled + const auto poolString = pool->toString(); + EXPECT_FALSE(poolString.empty()); + EXPECT_TRUE( + poolString.find( + "======== 100 allocations of 12.50KB total size ========") != + std::string::npos); + for (int32_t i = 0; i < kNumIterations; i++) { pool->free(smallAllocs[i], kAllocSizes[0]); } @@ -4037,6 +4048,213 @@ TEST_P(MemoryPoolTest, allocationWithCoveredCollateral) { pool->freeContiguous(contiguousAllocation); } +TEST_P(MemoryPoolTest, transferTo) { + MemoryManager::Options options; + options.alignment = MemoryAllocator::kMinAlignment; + options.allocatorCapacity = kDefaultCapacity; + setupMemory(options); + auto manager = getMemoryManager(); + + auto largestSizeClass = manager->allocator()->largestSizeClass(); + std::vector pageCounts{ + largestSizeClass, + largestSizeClass + 1, + largestSizeClass / 10, + 1, + largestSizeClass * 2, + largestSizeClass * 3 + 1}; + + auto assertEqualBytes = [](const memory::MemoryPool* pool, + int64_t usedBytes, + int64_t peakBytes, + int64_t reservedBytes) { + EXPECT_EQ(pool->usedBytes(), usedBytes); + EXPECT_EQ(pool->peakBytes(), peakBytes); + EXPECT_EQ(pool->reservedBytes(), reservedBytes); + }; + + auto assertZeroByte = [](const memory::MemoryPool* pool) { + EXPECT_EQ(pool->usedBytes(), 0); + EXPECT_EQ(pool->reservedBytes(), 0); + }; + + auto getMemoryBytes = [](const memory::MemoryPool* pool) { + return std::make_tuple( + pool->usedBytes(), 
pool->peakBytes(), pool->reservedBytes()); + }; + + auto createPools = [&manager](bool betweenDifferentRoots) { + auto root1 = manager->addRootPool("root1"); + auto root2 = manager->addRootPool("root2"); + std::shared_ptr from; + std::shared_ptr to; + if (betweenDifferentRoots) { + from = root1->addLeafChild("from"); + to = root2->addLeafChild("to"); + } else { + from = root1->addLeafChild("from"); + to = root1->addLeafChild("to"); + } + return std::make_tuple(root1, root2, from, to); + }; + + auto testTransferAllocate = [&assertZeroByte, + &assertEqualBytes, + &getMemoryBytes, + &createPools](bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + const auto kSize = 1024; + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + auto buffer = from->allocate(kSize); + // Transferring between non-leaf pools is not allowed. 
+ EXPECT_FALSE(from->root()->transferTo(to.get(), buffer, kSize)); + EXPECT_FALSE(from->transferTo(to->root(), buffer, kSize)); + + std::tie(usedBytes, peakBytes, reservedBytes) = getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + from->transferTo(to.get(), buffer, kSize); + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + rootPeakBytes *= 2; + } + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + to->free(buffer, kSize); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + auto testTransferAllocateZeroFilled = + [&assertZeroByte, &assertEqualBytes, &getMemoryBytes, &createPools]( + bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + const auto kSize = 1024; + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + auto buffer = from->allocateZeroFilled(8, kSize / 8); + std::tie(usedBytes, peakBytes, reservedBytes) = + getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + from->transferTo(to.get(), buffer, kSize); + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + rootPeakBytes *= 2; + } + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + to->free(buffer, kSize); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + auto testTransferAllocateContiguous = + [&assertZeroByte, &assertEqualBytes, &getMemoryBytes, &createPools]( + uint64_t pageCount, bool betweenDifferentRoots) { + 
auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + ContiguousAllocation out; + from->allocateContiguous(pageCount, out); + std::tie(usedBytes, peakBytes, reservedBytes) = + getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + from->transferTo(to.get(), out.data(), out.size()); + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + rootPeakBytes *= 2; + } + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + to->freeContiguous(out); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + auto testTransferAllocateNonContiguous = + [&assertZeroByte, &assertEqualBytes, &getMemoryBytes, &createPools]( + uint64_t pageCount, bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + Allocation out; + from->allocateNonContiguous(pageCount, out); + std::tie(usedBytes, peakBytes, reservedBytes) = + getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + for (auto i = 0; i < out.numRuns(); ++i) { + const auto& run = out.runAt(i); + from->transferTo(to.get(), run.data(), run.numBytes()); + } + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + EXPECT_EQ(to->root()->usedBytes(), rootUsedBytes); + // We reserve 
and release memory run-by-run, so the peak bytes would + // be no greater than twice of the original peak bytes. + EXPECT_LE(to->root()->peakBytes(), rootPeakBytes * 2); + EXPECT_EQ(to->root()->reservedBytes(), rootReservedBytes); + } else { + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + } + to->freeNonContiguous(out); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + // Test transfer between siblings of the same root pool. + testTransferAllocate(false); + testTransferAllocateZeroFilled(false); + for (auto pageCount : pageCounts) { + testTransferAllocateContiguous(pageCount, false); + testTransferAllocateNonContiguous(pageCount, false); + } + + // Test transfer between different root pools. + testTransferAllocate(true); + testTransferAllocateZeroFilled(true); + for (auto pageCount : pageCounts) { + testTransferAllocateContiguous(pageCount, true); + testTransferAllocateNonContiguous(pageCount, true); + } +} + VELOX_INSTANTIATE_TEST_SUITE_P( MemoryPoolTestSuite, MemoryPoolTest, diff --git a/velox/common/memory/tests/SharedArbitratorTest.cpp b/velox/common/memory/tests/SharedArbitratorTest.cpp index f01cc043da5c..7dae474c2966 100644 --- a/velox/common/memory/tests/SharedArbitratorTest.cpp +++ b/velox/common/memory/tests/SharedArbitratorTest.cpp @@ -1424,12 +1424,20 @@ TEST_P(SharedArbitrationTestWithThreadingModes, reserveReleaseCounters) { VELOX_INSTANTIATE_TEST_SUITE_P( SharedArbitrationTest, SharedArbitrationTestWithParallelExecutionModeOnly, - testing::ValuesIn(std::vector{{false}})); + testing::ValuesIn(std::vector{{false}}), + [](const testing::TestParamInfo& info) { + return fmt::format( + "{}", info.param.isSerialExecutionMode ? 
"serial" : "parallel"); + }); VELOX_INSTANTIATE_TEST_SUITE_P( SharedArbitrationTest, SharedArbitrationTestWithThreadingModes, - testing::ValuesIn(std::vector{{false}, {true}})); + testing::ValuesIn(std::vector{{false}, {true}}), + [](const testing::TestParamInfo& info) { + return fmt::format( + "{}", info.param.isSerialExecutionMode ? "serial" : "parallel"); + }); } // namespace facebook::velox::memory int main(int argc, char** argv) { diff --git a/velox/common/time/CpuWallTimer.h b/velox/common/time/CpuWallTimer.h index 231c15f66c27..8725a918df5d 100644 --- a/velox/common/time/CpuWallTimer.h +++ b/velox/common/time/CpuWallTimer.h @@ -29,6 +29,8 @@ struct CpuWallTiming { uint64_t wallNanos = 0; uint64_t cpuNanos = 0; + auto operator<=>(const CpuWallTiming&) const = default; + void add(const CpuWallTiming& other) { count += other.count; cpuNanos += other.cpuNanos; diff --git a/velox/connectors/Connector.cpp b/velox/connectors/Connector.cpp index e7c72af478ba..ba9d1a9c4154 100644 --- a/velox/connectors/Connector.cpp +++ b/velox/connectors/Connector.cpp @@ -18,12 +18,6 @@ namespace facebook::velox::connector { namespace { -std::unordered_map>& -connectorFactories() { - static std::unordered_map> - factories; - return factories; -} std::unordered_map>& connectors() { static std::unordered_map> connectors; @@ -43,35 +37,6 @@ std::string DataSink::Stats::toString() const { spillStats.toString()); } -bool registerConnectorFactory(std::shared_ptr factory) { - bool ok = - connectorFactories().insert({factory->connectorName(), factory}).second; - VELOX_CHECK( - ok, - "ConnectorFactory with name '{}' is already registered", - factory->connectorName()); - return true; -} - -bool hasConnectorFactory(const std::string& connectorName) { - return connectorFactories().count(connectorName) == 1; -} - -bool unregisterConnectorFactory(const std::string& connectorName) { - auto count = connectorFactories().erase(connectorName); - return count == 1; -} - -std::shared_ptr 
getConnectorFactory( - const std::string& connectorName) { - auto it = connectorFactories().find(connectorName); - VELOX_CHECK( - it != connectorFactories().end(), - "ConnectorFactory with name '{}' not registered", - connectorName); - return it->second; -} - bool registerConnector(std::shared_ptr connector) { bool ok = connectors().insert({connector->connectorId(), connector}).second; VELOX_CHECK( diff --git a/velox/connectors/Connector.h b/velox/connectors/Connector.h index 10cf82b342c6..08a40acc64fa 100644 --- a/velox/connectors/Connector.h +++ b/velox/connectors/Connector.h @@ -27,6 +27,7 @@ #include "velox/common/file/TokenProvider.h" #include "velox/common/future/VeloxPromise.h" #include "velox/core/ExpressionEvaluator.h" +#include "velox/core/QueryConfig.h" #include "velox/type/Filter.h" #include "velox/vector/ComplexVector.h" @@ -266,7 +267,15 @@ class DataSource { /// Returns the number of input rows processed so far. virtual uint64_t getCompletedRows() = 0; - virtual std::unordered_map runtimeStats() = 0; +#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY + virtual std::unordered_map runtimeStats() { + return {}; + } +#endif + + virtual std::unordered_map getRuntimeStats() { + return {}; + } /// Returns true if 'this' has initiated all the prefetch this will initiate. 
/// This means that the caller should schedule next splits to prefetch in the @@ -509,6 +518,14 @@ class ConnectorQueryCtx { selectiveNimbleReaderEnabled_ = value; } + core::QueryConfig::RowSizeTrackingMode rowSizeTrackingMode() const { + return rowSizeTrackingEnabled_; + } + + void setRowSizeTrackingMode(core::QueryConfig::RowSizeTrackingMode value) { + rowSizeTrackingEnabled_ = value; + } + std::shared_ptr fsTokenProvider() const { return fsTokenProvider_; } @@ -531,8 +548,85 @@ class ConnectorQueryCtx { const folly::CancellationToken cancellationToken_; const std::shared_ptr fsTokenProvider_; bool selectiveNimbleReaderEnabled_{false}; + core::QueryConfig::RowSizeTrackingMode rowSizeTrackingEnabled_{ + core::QueryConfig::RowSizeTrackingMode::ENABLED_FOR_ALL}; }; +class Connector; + +class ConnectorFactory { + public: + explicit ConnectorFactory(const char* name) : name_(name) {} + + virtual ~ConnectorFactory() = default; + + const std::string& connectorName() const { + return name_; + } + + virtual std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor = nullptr, + folly::Executor* cpuExecutor = nullptr) = 0; + + private: + const std::string name_; +}; + +#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY +namespace detail { +inline std::unordered_map>& +connectorFactories() { + static std::unordered_map> + factories; + return factories; +} +} // namespace detail + +/// Adds a factory for creating connectors to the registry using connector +/// name as the key. Throws if factor with the same name is already present. +/// Always returns true. The return value makes it easy to use with +/// FB_ANONYMOUS_VARIABLE. 
+inline bool registerConnectorFactory( + std::shared_ptr factory) { + bool ok = detail::connectorFactories() + .insert({factory->connectorName(), factory}) + .second; + VELOX_CHECK( + ok, + "ConnectorFactory with name '{}' is already registered", + factory->connectorName()); + return true; +} + +/// Returns true if a connector with the specified name has been registered, +/// false otherwise. +inline bool hasConnectorFactory(const std::string& connectorName) { + return detail::connectorFactories().count(connectorName) == 1; +} + +/// Unregister a connector factory by name. +/// Returns true if a connector with the specified name has been +/// unregistered, false otherwise. +inline bool unregisterConnectorFactory(const std::string& connectorName) { + auto count = detail::connectorFactories().erase(connectorName); + return count == 1; +} + +/// Returns a factory for creating connectors with the specified name. +/// Throws if factory doesn't exist. +inline std::shared_ptr getConnectorFactory( + const std::string& connectorName) { + auto it = detail::connectorFactories().find(connectorName); + VELOX_CHECK( + it != detail::connectorFactories().end(), + "ConnectorFactory with name '{}' not registered", + connectorName); + return it->second; +} +#endif + class Connector { public: explicit Connector( @@ -566,15 +660,9 @@ class Connector { /// ConnectorSplit in addSplit(). If so, TableScan can preload splits /// so that file opening and metadata operations are off the Driver' /// thread. -#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY - virtual bool supportsSplitPreload() { - return false; - } -#else virtual bool supportsSplitPreload() const { return false; } -#endif /// Returns true if the connector supports index lookup, otherwise false. 
virtual bool supportsIndexLookup() const { @@ -672,46 +760,6 @@ class Connector { trackers_; }; -class ConnectorFactory { - public: - explicit ConnectorFactory(const char* name) : name_(name) {} - - virtual ~ConnectorFactory() = default; - - const std::string& connectorName() const { - return name_; - } - - virtual std::shared_ptr newConnector( - const std::string& id, - std::shared_ptr config, - folly::Executor* ioExecutor = nullptr, - folly::Executor* cpuExecutor = nullptr) = 0; - - private: - const std::string name_; -}; - -/// Adds a factory for creating connectors to the registry using connector -/// name as the key. Throws if factor with the same name is already present. -/// Always returns true. The return value makes it easy to use with -/// FB_ANONYMOUS_VARIABLE. -bool registerConnectorFactory(std::shared_ptr factory); - -/// Returns true if a connector with the specified name has been registered, -/// false otherwise. -bool hasConnectorFactory(const std::string& connectorName); - -/// Unregister a connector factory by name. -/// Returns true if a connector with the specified name has been -/// unregistered, false otherwise. -bool unregisterConnectorFactory(const std::string& connectorName); - -/// Returns a factory for creating connectors with the specified name. -/// Throws if factory doesn't exist. -std::shared_ptr getConnectorFactory( - const std::string& connectorName); - /// Adds connector instance to the registry using connector ID as the key. /// Throws if connector with the same ID is already present. Always returns /// true. The return value makes it easy to use with FB_ANONYMOUS_VARIABLE. 
diff --git a/velox/connectors/fuzzer/FuzzerConnector.h b/velox/connectors/fuzzer/FuzzerConnector.h index 53e94b5f638a..5b3f9bf74e22 100644 --- a/velox/connectors/fuzzer/FuzzerConnector.h +++ b/velox/connectors/fuzzer/FuzzerConnector.h @@ -77,7 +77,7 @@ class FuzzerDataSource : public DataSource { return completedBytes_; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { // TODO: Which stats do we want to expose here? return {}; } diff --git a/velox/connectors/hive/FileHandle.cpp b/velox/connectors/hive/FileHandle.cpp index 267691cce2ee..413ca86bef46 100644 --- a/velox/connectors/hive/FileHandle.cpp +++ b/velox/connectors/hive/FileHandle.cpp @@ -58,6 +58,7 @@ std::unique_ptr FileHandleGenerator::operator()( options.fileSize = properties->fileSize; options.readRangeHint = properties->readRangeHint; options.extraFileInfo = properties->extraFileInfo; + options.fileReadOps = properties->fileReadOps; } const auto& filename = key.filename; fileHandle->file = filesystems::getFileSystem(filename, properties_) diff --git a/velox/connectors/hive/FileProperties.h b/velox/connectors/hive/FileProperties.h index d3ed9e3cbd6b..a6158e1fec65 100644 --- a/velox/connectors/hive/FileProperties.h +++ b/velox/connectors/hive/FileProperties.h @@ -25,6 +25,7 @@ #pragma once +#include #include namespace facebook::velox { @@ -34,6 +35,7 @@ struct FileProperties { std::optional modificationTime; std::optional readRangeHint{std::nullopt}; std::shared_ptr extraFileInfo{nullptr}; + folly::F14FastMap fileReadOps{}; }; } // namespace facebook::velox diff --git a/velox/connectors/hive/HiveConfig.cpp b/velox/connectors/hive/HiveConfig.cpp index 3463ec767c63..8b354e2ed7d4 100644 --- a/velox/connectors/hive/HiveConfig.cpp +++ b/velox/connectors/hive/HiveConfig.cpp @@ -159,6 +159,15 @@ int32_t HiveConfig::prefetchRowGroups() const { return config_->get(kPrefetchRowGroups, 1); } +size_t HiveConfig::parallelUnitLoadCount( + const config::ConfigBase* 
session) const { + auto count = session->get( + kParallelUnitLoadCountSession, + config_->get(kParallelUnitLoadCount, 0)); + VELOX_CHECK_LE(count, 100, "parallelUnitLoadCount too large: {}", count); + return count; +} + int32_t HiveConfig::loadQuantum(const config::ConfigBase* session) const { return session->get( kLoadQuantumSession, config_->get(kLoadQuantum, 8 << 20)); @@ -250,4 +259,18 @@ bool HiveConfig::preserveFlatMapsInMemory( config_->get(kPreserveFlatMapsInMemory, false)); } +std::string HiveConfig::user(const config::ConfigBase* session) const { + return session->get(kUser, config_->get(kUser, "")); +} + +std::string HiveConfig::source(const config::ConfigBase* session) const { + return session->get( + kSource, config_->get(kSource, "")); +} + +std::string HiveConfig::schema(const config::ConfigBase* session) const { + return session->get( + kSchema, config_->get(kSchema, "")); +} + } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveConfig.h b/velox/connectors/hive/HiveConfig.h index 6cf3d911af1a..f69da84ff0c1 100644 --- a/velox/connectors/hive/HiveConfig.h +++ b/velox/connectors/hive/HiveConfig.h @@ -75,11 +75,9 @@ class HiveConfig { "hive.gcs.auth.access-token-provider"; /// Maps table field names to file field names using names, not indices. - // TODO: remove hive_orc_use_column_names since it doesn't exist in presto, - // right now this is only used for testing. static constexpr const char* kOrcUseColumnNames = "hive.orc.use-column-names"; static constexpr const char* kOrcUseColumnNamesSession = - "hive_orc_use_column_names"; + "orc_use_column_names"; /// Maps table field names to file field names using names, not indices. static constexpr const char* kParquetUseColumnNames = @@ -146,6 +144,15 @@ class HiveConfig { /// meta data together. 
Optimization to decrease the small IO requests static constexpr const char* kFilePreloadThreshold = "file-preload-threshold"; + /// When set to be larger than 0, parallel unit loader feature is enabled and + /// it configures how many units (e.g., stripes) we load in parallel. + /// When set to 0, parallel unit loader feature is disabled and on demand unit + /// loader would be used. + static constexpr const char* kParallelUnitLoadCount = + "parallel-unit-load-count"; + static constexpr const char* kParallelUnitLoadCountSession = + "parallel_unit_load_count"; + /// Config used to create write files. This config is provided to underlying /// file system through hive connector and data sink. The config is free form. /// The form should be defined by the underlying file system. @@ -197,6 +204,10 @@ class HiveConfig { static constexpr const char* kPreserveFlatMapsInMemorySession = "hive.preserve_flat_maps_in_memory"; + static constexpr const char* kUser = "user"; + static constexpr const char* kSource = "source"; + static constexpr const char* kSchema = "schema"; + InsertExistingPartitionsBehavior insertExistingPartitionsBehavior( const config::ConfigBase* session) const; @@ -235,6 +246,8 @@ class HiveConfig { int32_t prefetchRowGroups() const; + size_t parallelUnitLoadCount(const config::ConfigBase* session) const; + int32_t loadQuantum(const config::ConfigBase* session) const; int32_t numCacheFileHandles() const; @@ -283,6 +296,15 @@ class HiveConfig { /// converting them to MapVectors. bool preserveFlatMapsInMemory(const config::ConfigBase* session) const; + /// User of the query. Used for storage logging. + std::string user(const config::ConfigBase* session) const; + + /// Source of the query. Used for storage access and logging. + std::string source(const config::ConfigBase* session) const; + + /// Schema of the query. Used for storage logging. 
+ std::string schema(const config::ConfigBase* session) const; + HiveConfig(std::shared_ptr config) { VELOX_CHECK_NOT_NULL( config, "Config is null for HiveConfig initialization"); diff --git a/velox/connectors/hive/HiveConnector.cpp b/velox/connectors/hive/HiveConnector.cpp index e04828e83aaa..950a84d00112 100644 --- a/velox/connectors/hive/HiveConnector.cpp +++ b/velox/connectors/hive/HiveConnector.cpp @@ -20,6 +20,7 @@ #include "velox/connectors/hive/HiveDataSink.h" #include "velox/connectors/hive/HiveDataSource.h" #include "velox/connectors/hive/HivePartitionFunction.h" +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" #include #include @@ -73,6 +74,16 @@ std::unique_ptr HiveConnector::createDataSink( ConnectorInsertTableHandlePtr connectorInsertTableHandle, ConnectorQueryCtx* connectorQueryCtx, CommitStrategy commitStrategy) { + if (auto icebergInsertHandle = + std::dynamic_pointer_cast( + connectorInsertTableHandle)) { + return std::make_unique( + inputType, + icebergInsertHandle, + connectorQueryCtx, + commitStrategy, + hiveConfig_); + } auto hiveInsertHandle = std::dynamic_pointer_cast( connectorInsertTableHandle); diff --git a/velox/connectors/hive/HiveConnector.h b/velox/connectors/hive/HiveConnector.h index c6b91392976e..7a4330b3eeb7 100644 --- a/velox/connectors/hive/HiveConnector.h +++ b/velox/connectors/hive/HiveConnector.h @@ -44,15 +44,9 @@ class HiveConnector : public Connector { const connector::ColumnHandleMap& columnHandles, ConnectorQueryCtx* connectorQueryCtx) override; -#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY - bool supportsSplitPreload() override { - return true; - } -#else bool supportsSplitPreload() const override { return true; } -#endif std::unique_ptr createDataSink( RowTypePtr inputType, diff --git a/velox/connectors/hive/HiveConnectorUtil.cpp b/velox/connectors/hive/HiveConnectorUtil.cpp index e06ee973ec75..2eb93edd858d 100644 --- a/velox/connectors/hive/HiveConnectorUtil.cpp +++ 
b/velox/connectors/hive/HiveConnectorUtil.cpp @@ -21,6 +21,7 @@ #include "velox/dwio/common/CachedBufferedInput.h" #include "velox/dwio/common/DirectBufferedInput.h" #include "velox/expression/Expr.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" namespace facebook::velox::connector::hive { @@ -612,6 +613,7 @@ void configureRowReaderOptions( const std::shared_ptr& hiveSplit, const std::shared_ptr& hiveConfig, const config::ConfigBase* sessionProperties, + folly::Executor* const ioExecutor, dwio::common::RowReaderOptions& rowReaderOptions) { auto skipRowsIt = tableParameters.find(dwio::common::TableParameter::kSkipHeaderLineCount); @@ -619,6 +621,7 @@ void configureRowReaderOptions( rowReaderOptions.setSkipRows(folly::to(skipRowsIt->second)); } rowReaderOptions.setScanSpec(scanSpec); + rowReaderOptions.setIOExecutor(ioExecutor); rowReaderOptions.setMetadataFilter(std::move(metadataFilter)); rowReaderOptions.setRequestedType(rowType); rowReaderOptions.range(hiveSplit->start, hiveSplit->length); @@ -627,6 +630,13 @@ void configureRowReaderOptions( hiveConfig->readTimestampUnit(sessionProperties))); rowReaderOptions.setPreserveFlatMapsInMemory( hiveConfig->preserveFlatMapsInMemory(sessionProperties)); + rowReaderOptions.setParallelUnitLoadCount( + hiveConfig->parallelUnitLoadCount(sessionProperties)); + // When parallel unit loader is enabled, all units would be loaded by + // ParallelUnitLoader, thus disable eagerFirstStripeLoad. 
+ if (hiveConfig->parallelUnitLoadCount(sessionProperties) > 0) { + rowReaderOptions.setEagerFirstStripeLoad(false); + } } rowReaderOptions.setSerdeParameters(hiveSplit->serdeParameters); } @@ -646,7 +656,7 @@ bool applyPartitionFilter( if (isPartitionDateDaysSinceEpoch) { result = folly::to(partitionValue); } else { - result = DATE()->toDays(static_cast(partitionValue)); + result = DATE()->toDays(partitionValue); } return applyFilter(*filter, result); } @@ -753,7 +763,8 @@ std::unique_ptr createBufferedInput( const ConnectorQueryCtx* connectorQueryCtx, std::shared_ptr ioStats, std::shared_ptr fsStats, - folly::Executor* executor) { + folly::Executor* executor, + const folly::F14FastMap& fileReadOps) { if (connectorQueryCtx->cache()) { return std::make_unique( fileHandle.file, @@ -766,7 +777,8 @@ std::unique_ptr createBufferedInput( ioStats, std::move(fsStats), executor, - readerOpts); + readerOpts, + fileReadOps); } if (readerOpts.fileFormat() == dwio::common::FileFormat::NIMBLE) { // Nimble streams (in case of single chunk) are compressed as whole and need @@ -778,7 +790,10 @@ std::unique_ptr createBufferedInput( readerOpts.memoryPool(), dwio::common::MetricsLog::voidLog(), ioStats.get(), - fsStats.get()); + fsStats.get(), + dwio::common::BufferedInput::kMaxMergeDistance, + std::nullopt, + fileReadOps); } return std::make_unique( fileHandle.file, @@ -790,7 +805,8 @@ std::unique_ptr createBufferedInput( std::move(ioStats), std::move(fsStats), executor, - readerOpts); + readerOpts, + fileReadOps); } namespace { @@ -899,19 +915,25 @@ core::TypedExprPtr extractFiltersFromRemainingFilter( return inner ? 
replaceInputs(call, {inner}) : nullptr; } - if ((call->name() == "and" && !negated) || - (call->name() == "or" && negated)) { - auto lhs = extractFiltersFromRemainingFilter( - call->inputs()[0], evaluator, negated, filters, sampleRate); - auto rhs = extractFiltersFromRemainingFilter( - call->inputs()[1], evaluator, negated, filters, sampleRate); - if (!lhs) { - return rhs; + if ((call->name() == expression::kAnd && !negated) || + (call->name() == expression::kOr && negated)) { + std::vector args; + args.reserve(call->inputs().size()); + for (const auto& input : call->inputs()) { + if (auto arg = extractFiltersFromRemainingFilter( + input, evaluator, negated, filters, sampleRate)) { + args.push_back(std::move(arg)); + } + // If extractFiltersFromRemainingFilter returns nullptr, it means + // everything in input is converted to filters. + } + if (args.empty()) { + return nullptr; } - if (!rhs) { - return lhs; + if (args.size() == 1) { + return std::move(args[0]); } - return replaceInputs(call, {lhs, rhs}); + return replaceInputs(call, std::move(args)); } if (!negated) { double rate = getPrestoSampleRate(expr, call, evaluator); diff --git a/velox/connectors/hive/HiveConnectorUtil.h b/velox/connectors/hive/HiveConnectorUtil.h index d649b12d0930..68ffbab2bb71 100644 --- a/velox/connectors/hive/HiveConnectorUtil.h +++ b/velox/connectors/hive/HiveConnectorUtil.h @@ -86,6 +86,7 @@ void configureRowReaderOptions( const std::shared_ptr& hiveSplit, const std::shared_ptr& hiveConfig, const config::ConfigBase* sessionProperties, + folly::Executor* ioExecutor, dwio::common::RowReaderOptions& rowReaderOptions); bool testFilters( @@ -105,7 +106,8 @@ std::unique_ptr createBufferedInput( const ConnectorQueryCtx* connectorQueryCtx, std::shared_ptr ioStats, std::shared_ptr fsStats, - folly::Executor* executor); + folly::Executor* executor, + const folly::F14FastMap& fileReadOps = {}); core::TypedExprPtr extractFiltersFromRemainingFilter( const core::TypedExprPtr& expr, diff --git 
a/velox/connectors/hive/HiveDataSink.cpp b/velox/connectors/hive/HiveDataSink.cpp index 36c797e4eafc..a4f6b1dc94db 100644 --- a/velox/connectors/hive/HiveDataSink.cpp +++ b/velox/connectors/hive/HiveDataSink.cpp @@ -682,7 +682,10 @@ bool HiveDataSink::finish() { std::vector HiveDataSink::close() { setState(State::kClosed); closeInternal(); + return commitMessage(); +} +std::vector HiveDataSink::commitMessage() const { std::vector partitionUpdates; partitionUpdates.reserve(writerInfo_.size()); for (int i = 0; i < writerInfo_.size(); ++i) { @@ -805,10 +808,11 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { options->spillConfig = spillConfig_; } - if (options->nonReclaimableSection == nullptr) { - options->nonReclaimableSection = - writerInfo_.back()->nonReclaimableSectionHolder.get(); - } + // Always set nonReclaimableSection to the current writer's holder. + // Since insertTableHandle_->writerOptions() returns a shared_ptr, we need + // to ensure each writer has its own nonReclaimableSection pointer. + options->nonReclaimableSection = + writerInfo_.back()->nonReclaimableSectionHolder.get(); if (options->memoryReclaimerFactory == nullptr || options->memoryReclaimerFactory() == nullptr) { @@ -845,6 +849,11 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { options); writer = maybeCreateBucketSortWriter(std::move(writer)); writers_.emplace_back(std::move(writer)); + addThreadLocalRuntimeStat( + fmt::format( + "{}WriterCount", + dwio::common::toString(insertTableHandle_->storageFormat())), + RuntimeCounter(1)); // Extends the buffer used for partition rows calculations. 
partitionSizes_.emplace_back(0); partitionRows_.emplace_back(nullptr); diff --git a/velox/connectors/hive/HiveDataSink.h b/velox/connectors/hive/HiveDataSink.h index 8c305b7595de..788294984967 100644 --- a/velox/connectors/hive/HiveDataSink.h +++ b/velox/connectors/hive/HiveDataSink.h @@ -337,9 +337,11 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { std::string toString() const override; - private: + protected: const std::vector> inputColumns_; const std::shared_ptr locationHandle_; + + private: const dwio::common::FileFormat storageFormat_; const std::shared_ptr bucketProperty_; const std::optional compressionKind_; @@ -544,11 +546,19 @@ class HiveDataSink : public DataSink { bool canReclaim() const; - private: + protected: // Validates the state transition from 'oldState' to 'newState'. void checkStateTransition(State oldState, State newState); void setState(State newState); + // Generates commit messages for all writers containing metadata about written + // files. Creates a JSON object for each writer with partition name, + // file paths, file names, data sizes, and row counts. This metadata is used + // by the coordinator to commit the transaction and update the metastore. + // + // @return Vector of JSON strings, one per writer. 
+ virtual std::vector commitMessage() const; + class WriterReclaimer : public exec::MemoryReclaimer { public: static std::unique_ptr create( diff --git a/velox/connectors/hive/HiveDataSource.cpp b/velox/connectors/hive/HiveDataSource.cpp index d0c59b6392c5..489c45e4dc3b 100644 --- a/velox/connectors/hive/HiveDataSource.cpp +++ b/velox/connectors/hive/HiveDataSource.cpp @@ -418,60 +418,58 @@ void HiveDataSource::addDynamicFilter( } } -std::unordered_map HiveDataSource::runtimeStats() { - auto res = runtimeStats_.toMap(); +std::unordered_map +HiveDataSource::getRuntimeStats() { + auto res = runtimeStats_.toRuntimeMetricMap(); res.insert( - {{"numPrefetch", RuntimeCounter(ioStats_->prefetch().count())}, + {{"numPrefetch", RuntimeMetric(ioStats_->prefetch().count())}, {"prefetchBytes", - RuntimeCounter( + RuntimeMetric( ioStats_->prefetch().sum(), RuntimeCounter::Unit::kBytes)}, {"totalScanTime", - RuntimeCounter( - ioStats_->totalScanTime(), RuntimeCounter::Unit::kNanos)}, + RuntimeMetric(ioStats_->totalScanTime(), RuntimeCounter::Unit::kNanos)}, {Connector::kTotalRemainingFilterTime, - RuntimeCounter( + RuntimeMetric( totalRemainingFilterTime_.load(std::memory_order_relaxed), RuntimeCounter::Unit::kNanos)}, {"ioWaitWallNanos", - RuntimeCounter( + RuntimeMetric( ioStats_->queryThreadIoLatency().sum() * 1000, RuntimeCounter::Unit::kNanos)}, {"maxSingleIoWaitWallNanos", - RuntimeCounter( + RuntimeMetric( ioStats_->queryThreadIoLatency().max() * 1000, RuntimeCounter::Unit::kNanos)}, {"overreadBytes", - RuntimeCounter( + RuntimeMetric( ioStats_->rawOverreadBytes(), RuntimeCounter::Unit::kBytes)}}); if (ioStats_->read().count() > 0) { - res.insert({"numStorageRead", RuntimeCounter(ioStats_->read().count())}); + res.insert({"numStorageRead", RuntimeMetric(ioStats_->read().count())}); res.insert( {"storageReadBytes", - RuntimeCounter(ioStats_->read().sum(), RuntimeCounter::Unit::kBytes)}); + RuntimeMetric(ioStats_->read().sum(), RuntimeCounter::Unit::kBytes)}); } if 
(ioStats_->ssdRead().count() > 0) { - res.insert({"numLocalRead", RuntimeCounter(ioStats_->ssdRead().count())}); + res.insert({"numLocalRead", RuntimeMetric(ioStats_->ssdRead().count())}); res.insert( {"localReadBytes", - RuntimeCounter( + RuntimeMetric( ioStats_->ssdRead().sum(), RuntimeCounter::Unit::kBytes)}); } if (ioStats_->ramHit().count() > 0) { - res.insert({"numRamRead", RuntimeCounter(ioStats_->ramHit().count())}); + res.insert({"numRamRead", RuntimeMetric(ioStats_->ramHit().count())}); res.insert( {"ramReadBytes", - RuntimeCounter( + RuntimeMetric( ioStats_->ramHit().sum(), RuntimeCounter::Unit::kBytes)}); } if (numBucketConversion_ > 0) { - res.insert({"numBucketConversion", RuntimeCounter(numBucketConversion_)}); + res.insert({"numBucketConversion", RuntimeMetric(numBucketConversion_)}); } const auto fsStats = fsStats_->stats(); for (const auto& storageStats : fsStats) { - res.emplace( - storageStats.first, - RuntimeCounter(storageStats.second.sum, storageStats.second.unit)); + res.emplace(storageStats.first, storageStats.second); } return res; } diff --git a/velox/connectors/hive/HiveDataSource.h b/velox/connectors/hive/HiveDataSource.h index 64aa3d6420bf..87138b3cde48 100644 --- a/velox/connectors/hive/HiveDataSource.h +++ b/velox/connectors/hive/HiveDataSource.h @@ -60,7 +60,7 @@ class HiveDataSource : public DataSource { return completedRows_; } - std::unordered_map runtimeStats() override; + std::unordered_map getRuntimeStats() override; bool allPrefetchIssued() const override { return splitReader_ && splitReader_->allPrefetchIssued(); diff --git a/velox/connectors/hive/HivePartitionFunction.cpp b/velox/connectors/hive/HivePartitionFunction.cpp index d273cc8163e3..1678ed7af503 100644 --- a/velox/connectors/hive/HivePartitionFunction.cpp +++ b/velox/connectors/hive/HivePartitionFunction.cpp @@ -15,6 +15,8 @@ */ #include "velox/connectors/hive/HivePartitionFunction.h" +#include + namespace facebook::velox::connector::hive { namespace { @@ -26,6 
+28,46 @@ int32_t hashInt64(int64_t value) { return ((*reinterpret_cast(&value)) >> 32) ^ value; } +template +inline int32_t hashDecimal(T value, uint8_t scale) { + bool isNegative = value < 0; + uint64_t absValue = + isNegative ? -static_cast(value) : static_cast(value); + + uint32_t high = absValue >> 32; + uint32_t low = absValue; + + uint32_t hash = 31 * high + low; + if (isNegative) { + hash = -hash; + } + + return 31 * hash + scale; +} + +// Simulates Hive's hashing function from Hive v1.2.1 +// org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() +// Returns java BigDecimal#hashCode() +template <> +inline int32_t hashDecimal(int128_t value, uint8_t scale) { + uint32_t words[4]; + bool isNegative = value < 0; + uint128_t absValue = isNegative ? -value : value; + words[0] = absValue >> 96; + words[1] = absValue >> 64; + words[2] = absValue >> 32; + words[3] = absValue; + + uint32_t hash = 0; + for (auto i = 0; i < 4; i++) { + hash = 31 * hash + words[i]; + } + if (isNegative) { + hash = -hash; + } + return hash * 31 + scale; +} + #if defined(__has_feature) #if __has_feature(__address_sanitizer__) __attribute__((no_sanitize("integer"))) @@ -114,6 +156,34 @@ void hashPrimitive( const SelectivityVector& rows, bool mix, std::vector& hashes) { + const auto& type = values.base()->type(); + if constexpr (kind == TypeKind::BIGINT || kind == TypeKind::HUGEINT) { + if (type->isDecimal()) { + const auto scale = getDecimalPrecisionScale(*type).second; + if (rows.isAllSelected()) { + vector_size_t numRows = rows.size(); + for (auto i = 0; i < numRows; ++i) { + const uint32_t hash = values.isNullAt(i) + ? 0 + : hashDecimal( + values.valueAt::NativeType>(i), + scale); + mergeHash(mix, hash, hashes[i]); + } + } else { + rows.applyToSelected([&](auto row) INLINE_LAMBDA { + const uint32_t hash = values.isNullAt(row) + ? 
0 + : hashDecimal( + values.valueAt::NativeType>(row), + scale); + mergeHash(mix, hash, hashes[row]); + }); + } + return; + } + } + if (rows.isAllSelected()) { // The compiler seems to be a little fickle with optimizations. // Although rows.applyToSelected should do roughly the same thing, doing @@ -210,6 +280,16 @@ void HivePartitionFunction::hashTyped( hashPrimitive(values, rows, mix, hashes); } +template <> +void HivePartitionFunction::hashTyped( + const DecodedVector& values, + const SelectivityVector& rows, + bool mix, + std::vector& hashes, + size_t /* poolIndex */) { + hashPrimitive(values, rows, mix, hashes); +} + template <> void HivePartitionFunction::hashTyped( const DecodedVector& values, @@ -461,7 +541,7 @@ HivePartitionFunction::HivePartitionFunction( std::vector keyChannels, const std::vector& constValues) : numBuckets_{numBuckets}, - bucketToPartition_{bucketToPartition}, + bucketToPartition_{std::move(bucketToPartition)}, keyChannels_{std::move(keyChannels)} { precomputedHashes_.resize(keyChannels_.size()); size_t constChannel{0}; @@ -495,7 +575,7 @@ std::optional HivePartitionFunction::partition( } } - static const int32_t kInt32Max = std::numeric_limits::max(); + static constexpr int32_t kInt32Max = std::numeric_limits::max(); if (bucketToPartition_.empty()) { // NOTE: if bucket to partition mapping is empty, then we do diff --git a/velox/connectors/hive/HivePartitionUtil.cpp b/velox/connectors/hive/HivePartitionUtil.cpp index cb95b916df35..7e76d53503ea 100644 --- a/velox/connectors/hive/HivePartitionUtil.cpp +++ b/velox/connectors/hive/HivePartitionUtil.cpp @@ -27,6 +27,7 @@ namespace facebook::velox::connector::hive { case TypeKind::SMALLINT: \ case TypeKind::INTEGER: \ case TypeKind::BIGINT: \ + case TypeKind::HUGEINT: \ case TypeKind::VARCHAR: \ case TypeKind::VARBINARY: \ case TypeKind::TIMESTAMP: \ @@ -89,6 +90,22 @@ std::pair makePartitionKeyValueString( DATE()->toString( partitionVector->as>()->valueAt(row))); } + if constexpr (Kind == 
TypeKind::BIGINT || Kind == TypeKind::HUGEINT) { + if (partitionVector->type()->isDecimal()) { + auto [precision, scale] = + getDecimalPrecisionScale(*partitionVector->type()); + const auto maxStringSize = + DecimalUtil::maxStringViewSize(precision, scale); + std::vector maxString(maxStringSize); + const auto size = DecimalUtil::castToString( + partitionVector->as>()->valueAt(row), + scale, + maxStringSize, + maxString.data()); + return std::make_pair(name, std::string(maxString.data(), size)); + } + } + return std::make_pair( name, makePartitionValueString( diff --git a/velox/connectors/hive/SplitReader.cpp b/velox/connectors/hive/SplitReader.cpp index 166a1f449b6a..dc6d296afd59 100644 --- a/velox/connectors/hive/SplitReader.cpp +++ b/velox/connectors/hive/SplitReader.cpp @@ -48,12 +48,24 @@ VectorPtr newConstantFromString( if (isPartitionDateDaysSinceEpoch) { days = folly::to(value.value()); } else { - days = DATE()->toDays(static_cast(value.value())); + days = DATE()->toDays(value.value()); } return std::make_shared>( pool, size, false, type, std::move(days)); } + if constexpr (std::is_same_v || std::is_same_v) { + if (type->isDecimal()) { + auto [precision, scale] = getDecimalPrecisionScale(*type); + T result; + const auto status = DecimalUtil::castFromString( + StringView(value.value()), precision, scale, result); + VELOX_USER_CHECK(status.ok(), status.message()); + return std::make_shared>( + pool, size, false, type, std::move(result)); + } + } + if constexpr (std::is_same_v) { return std::make_shared>( pool, size, false, type, StringView(value.value())); @@ -158,8 +170,9 @@ void SplitReader::configureReaderOptions( void SplitReader::prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats) { - createReader(); + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps) { + createReader(fileReadOps); if (emptySplit_) { return; } @@ -170,7 +183,7 @@ void SplitReader::prepareSplit( return; } - 
createRowReader(std::move(metadataFilter), std::move(rowType)); + createRowReader(std::move(metadataFilter), std::move(rowType), std::nullopt); } void SplitReader::setBucketConversion( @@ -281,7 +294,8 @@ std::string SplitReader::toString() const { static_cast(baseRowReader_.get())); } -void SplitReader::createReader() { +void SplitReader::createReader( + const folly::F14FastMap& fileReadOps) { VELOX_CHECK_NE( baseReaderOpts_.fileFormat(), dwio::common::FileFormat::UNKNOWN); @@ -289,11 +303,13 @@ void SplitReader::createReader() { FileHandleKey fileHandleKey{ .filename = hiveSplit_->filePath, .tokenProvider = connectorQueryCtx_->fsTokenProvider()}; + + auto fileProperties = hiveSplit_->properties.value_or(FileProperties{}); + fileProperties.fileReadOps = fileReadOps; + try { fileHandleCachePtr = fileHandleFactory_->generate( - fileHandleKey, - hiveSplit_->properties.has_value() ? &*hiveSplit_->properties : nullptr, - fsStats_ ? fsStats_.get() : nullptr); + fileHandleKey, &fileProperties, fsStats_ ? 
fsStats_.get() : nullptr); VELOX_CHECK_NOT_NULL(fileHandleCachePtr.get()); } catch (const VeloxRuntimeError& e) { if (e.errorCode() == error_code::kFileNotFound && @@ -318,7 +334,8 @@ void SplitReader::createReader() { connectorQueryCtx_, ioStats_, fsStats_, - ioExecutor_); + ioExecutor_, + fileReadOps); baseReader_ = dwio::common::getReaderFactory(baseReaderOpts_.fileFormat()) ->createReader(std::move(baseFileInput), baseReaderOpts_); @@ -368,7 +385,8 @@ bool SplitReader::checkIfSplitIsEmpty( void SplitReader::createRowReader( std::shared_ptr metadataFilter, - RowTypePtr rowType) { + RowTypePtr rowType, + std::optional rowSizeTrackingEnabled) { VELOX_CHECK_NULL(baseRowReader_); configureRowReaderOptions( hiveTableHandle_->tableParameters(), @@ -378,7 +396,13 @@ void SplitReader::createRowReader( hiveSplit_, hiveConfig_, connectorQueryCtx_->sessionProperties(), + ioExecutor_, baseRowReaderOpts_); + baseRowReaderOpts_.setTrackRowSize( + rowSizeTrackingEnabled.has_value() + ? *rowSizeTrackingEnabled + : connectorQueryCtx_->rowSizeTrackingMode() != + core::QueryConfig::RowSizeTrackingMode::DISABLED); baseRowReader_ = baseReader_->createRowReader(baseRowReaderOpts_); } diff --git a/velox/connectors/hive/SplitReader.h b/velox/connectors/hive/SplitReader.h index 72a42b56b0ba..1ad92073c5cb 100644 --- a/velox/connectors/hive/SplitReader.h +++ b/velox/connectors/hive/SplitReader.h @@ -80,7 +80,8 @@ class SplitReader { /// would be called only once per incoming split virtual void prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats); + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps = {}); virtual uint64_t next(uint64_t size, VectorPtr& output); @@ -124,7 +125,8 @@ class SplitReader { /// Create the dwio::common::Reader object baseReader_, which will be used to /// read the data file's metadata and schema - void createReader(); + void createReader( + const folly::F14FastMap& fileReadOps = {}); // 
Adjust the scan spec according to the current split, then return the // adapted row type. @@ -147,7 +149,8 @@ class SplitReader { /// ColumnReaders that will be used to read the data void createRowReader( std::shared_ptr metadataFilter, - RowTypePtr rowType); + RowTypePtr rowType, + std::optional rowSizeTrackingEnabled); const folly::F14FastSet& bucketChannels() const { return bucketChannels_; diff --git a/velox/connectors/hive/iceberg/CMakeLists.txt b/velox/connectors/hive/iceberg/CMakeLists.txt index 329998b5d40d..6b59de5bbbf8 100644 --- a/velox/connectors/hive/iceberg/CMakeLists.txt +++ b/velox/connectors/hive/iceberg/CMakeLists.txt @@ -14,6 +14,7 @@ velox_add_library( velox_hive_iceberg_splitreader + IcebergDataSink.cpp IcebergSplitReader.cpp IcebergSplit.cpp PositionalDeleteFileReader.cpp diff --git a/velox/connectors/hive/iceberg/IcebergDataSink.cpp b/velox/connectors/hive/iceberg/IcebergDataSink.cpp new file mode 100644 index 000000000000..a9268a8b3046 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataSink.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" +#include "velox/common/base/Fs.h" + +namespace facebook::velox::connector::hive::iceberg { + +IcebergInsertTableHandle::IcebergInsertTableHandle( + std::vector inputColumns, + LocationHandlePtr locationHandle, + dwio::common::FileFormat tableStorageFormat, + std::optional compressionKind, + const std::unordered_map& serdeParameters) + : HiveInsertTableHandle( + std::move(inputColumns), + std::move(locationHandle), + tableStorageFormat, + nullptr, + compressionKind, + serdeParameters, + nullptr, + false, + std::make_shared()) { + VELOX_USER_CHECK( + !inputColumns_.empty(), + "Input columns cannot be empty for Iceberg tables."); + VELOX_USER_CHECK_NOT_NULL( + locationHandle_, "Location handle is required for Iceberg tables."); +} + +IcebergDataSink::IcebergDataSink( + RowTypePtr inputType, + IcebergInsertTableHandlePtr insertTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + const std::shared_ptr& hiveConfig) + : HiveDataSink( + std::move(inputType), + insertTableHandle, + connectorQueryCtx, + commitStrategy, + hiveConfig, + 0, + nullptr) {} + +std::vector IcebergDataSink::commitMessage() const { + std::vector commitTasks; + commitTasks.reserve(writerInfo_.size()); + + for (auto i = 0; i < writerInfo_.size(); ++i) { + const auto& info = writerInfo_.at(i); + VELOX_CHECK_NOT_NULL(info); + // Following metadata (json format) is consumed by Presto CommitTaskData. + // It contains the minimal subset of metadata. + // TODO: Complete metrics is missing now and this could lead to suboptimal + // query plan, will collect full iceberg metrics in following PR. 
+ // clang-format off + folly::dynamic commitData = folly::dynamic::object( + "path", (fs::path(info->writerParameters.writeDirectory()) / + info->writerParameters.writeFileName()).string()) + ("fileSizeInBytes", ioStats_.at(i)->rawBytesWritten()) + ("metrics", + folly::dynamic::object("recordCount", info->numWrittenRows)) + ("partitionSpecJson", 0) + ("fileFormat", "PARQUET") + ("content", "DATA"); + // clang-format on + auto commitDataJson = folly::toJson(commitData); + commitTasks.push_back(commitDataJson); + } + return commitTasks; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataSink.h b/velox/connectors/hive/iceberg/IcebergDataSink.h new file mode 100644 index 000000000000..d8dfcc933cbd --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataSink.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/connectors/hive/HiveDataSink.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Represents a request for Iceberg write. +class IcebergInsertTableHandle final : public HiveInsertTableHandle { + public: + /// @param inputColumns Columns from the table schema to write. + /// The input RowVector must have the same number of columns and matching + /// types in the same order. 
+ /// Column names in the RowVector may differ from those in inputColumns, + /// only position and type must align. All columns present in the input + /// data must be included, mismatches can lead to write failure. + /// @param locationHandle Contains the target location information including: + /// - Base directory path where data files will be written. + /// - File naming scheme and temporary directory paths. + /// @param compressionKind Optional compression to apply to data files. + /// @param serdeParameters Additional serialization/deserialization parameters + /// for the file format. + IcebergInsertTableHandle( + std::vector inputColumns, + LocationHandlePtr locationHandle, + dwio::common::FileFormat tableStorageFormat, + std::optional compressionKind = {}, + const std::unordered_map& serdeParameters = {}); +}; + +using IcebergInsertTableHandlePtr = + std::shared_ptr; + +class IcebergDataSink : public HiveDataSink { + public: + IcebergDataSink( + RowTypePtr inputType, + IcebergInsertTableHandlePtr insertTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + const std::shared_ptr& hiveConfig); + + /// Generates Iceberg-specific commit messages for all writers containing + /// metadata about written files. Creates a JSON object for each writer + /// in the format expected by Presto and Spark for Iceberg tables. + /// + /// Each commit message contains: + /// - path: full file path where data was written. + /// - fileSizeInBytes: raw bytes written to disk. + /// - metrics: object with recordCount (number of rows written). + /// - partitionSpecJson: partition specification. + /// - fileFormat: storage format (e.g., "PARQUET"). + /// - content: file content type ("DATA" for data files). 
+ /// + /// See + /// https://github.com/prestodb/presto/blob/master/presto-iceberg/src/main/java/com/facebook/presto/iceberg/CommitTaskData.java + /// + /// Note: Complete Iceberg metrics are not yet implemented, which results in + /// incomplete manifest files that may lead to suboptimal query planning. + /// + /// @return Vector of JSON strings, one per writer, formatted according to + /// Presto and Spark Iceberg commit protocol. + std::vector commitMessage() const override; +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergSplitReader.cpp b/velox/connectors/hive/iceberg/IcebergSplitReader.cpp index 8b8e9fe6ffa3..0a4c87a063af 100644 --- a/velox/connectors/hive/iceberg/IcebergSplitReader.cpp +++ b/velox/connectors/hive/iceberg/IcebergSplitReader.cpp @@ -54,7 +54,8 @@ IcebergSplitReader::IcebergSplitReader( void IcebergSplitReader::prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats) { + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps) { createReader(); if (emptySplit_) { return; @@ -66,7 +67,7 @@ void IcebergSplitReader::prepareSplit( return; } - createRowReader(std::move(metadataFilter), std::move(rowType)); + createRowReader(std::move(metadataFilter), std::move(rowType), std::nullopt); std::shared_ptr icebergSplit = std::dynamic_pointer_cast(hiveSplit_); diff --git a/velox/connectors/hive/iceberg/IcebergSplitReader.h b/velox/connectors/hive/iceberg/IcebergSplitReader.h index 4b3c6b901048..331bd1db12ae 100644 --- a/velox/connectors/hive/iceberg/IcebergSplitReader.h +++ b/velox/connectors/hive/iceberg/IcebergSplitReader.h @@ -43,7 +43,9 @@ class IcebergSplitReader : public SplitReader { void prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats) override; + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps = {}) + override; uint64_t next(uint64_t 
size, VectorPtr& output) override; diff --git a/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp b/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp index d36550ac66a8..0b4a185e75f9 100644 --- a/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp +++ b/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp @@ -137,6 +137,7 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( deleteSplit_, nullptr, nullptr, + nullptr, deleteRowReaderOpts); deleteRowReader_.reset(); diff --git a/velox/connectors/hive/iceberg/tests/CMakeLists.txt b/velox/connectors/hive/iceberg/tests/CMakeLists.txt index 3e54d5431754..7ca03f9a653d 100644 --- a/velox/connectors/hive/iceberg/tests/CMakeLists.txt +++ b/velox/connectors/hive/iceberg/tests/CMakeLists.txt @@ -56,9 +56,29 @@ if(NOT VELOX_DISABLE_GOOGLETEST) GTest::gtest GTest::gtest_main ) + + add_executable(velox_hive_iceberg_insert_test IcebergInsertTest.cpp IcebergTestBase.cpp Main.cpp) + + add_test(velox_hive_iceberg_insert_test velox_hive_iceberg_insert_test) + + target_link_libraries( + velox_hive_iceberg_insert_test + velox_exec_test_lib + velox_hive_connector + velox_hive_iceberg_splitreader + velox_vector_fuzzer + GTest::gtest + ) + if(VELOX_ENABLE_PARQUET) target_link_libraries(velox_hive_iceberg_test velox_dwio_parquet_reader) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + + target_link_libraries( + velox_hive_iceberg_insert_test + velox_dwio_parquet_reader + velox_dwio_parquet_writer + ) endif() endif() diff --git a/velox/connectors/hive/iceberg/tests/IcebergInsertTest.cpp b/velox/connectors/hive/iceberg/tests/IcebergInsertTest.cpp new file mode 100644 index 000000000000..0e38e5340358 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergInsertTest.cpp @@ -0,0 +1,81 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::velox::connector::hive::iceberg { +namespace { + +class IcebergInsertTest : public test::IcebergTestBase { + protected: + void test(const RowTypePtr& rowType, double nullRatio = 0.0) { + const auto outputDirectory = exec::test::TempDirectoryPath::create(); + const auto dataPath = fmt::format("{}", outputDirectory->getPath()); + constexpr int32_t numBatches = 10; + constexpr int32_t vectorSize = 5'000; + const auto vectors = + createTestData(rowType, numBatches, vectorSize, nullRatio); + auto dataSink = + createIcebergDataSink(rowType, outputDirectory->getPath(), {}); + + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + ASSERT_TRUE(dataSink->finish()); + const auto commitTasks = dataSink->close(); + createDuckDbTable(vectors); + auto splits = createSplitsForDirectory(dataPath); + ASSERT_EQ(splits.size(), commitTasks.size()); + auto plan = exec::test::PlanBuilder().tableScan(rowType).planNode(); + assertQuery(plan, splits, "SELECT * FROM tmp"); + } +}; + +TEST_F(IcebergInsertTest, basic) { + auto rowType = + ROW({"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11"}, + {BIGINT(), + INTEGER(), + SMALLINT(), + BOOLEAN(), + REAL(), + DECIMAL(18, 5), + VARCHAR(), + VARBINARY(), + DATE(), + TIMESTAMP(), + ROW({"id", "name"}, 
{INTEGER(), VARCHAR()})}); + test(rowType, 0.2); +} + +TEST_F(IcebergInsertTest, mapAndArray) { + auto rowType = + ROW({"c1", "c2"}, {MAP(INTEGER(), VARCHAR()), ARRAY(VARCHAR())}); + test(rowType); +} + +#ifdef VELOX_ENABLE_PARQUET +TEST_F(IcebergInsertTest, bigDecimal) { + auto rowType = ROW({"c1"}, {DECIMAL(38, 5)}); + fileFormat_ = dwio::common::FileFormat::PARQUET; + test(rowType); +} +#endif + +} // namespace +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/IcebergTestBase.cpp b/velox/connectors/hive/iceberg/tests/IcebergTestBase.cpp new file mode 100644 index 000000000000..aee8ecf2ed81 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergTestBase.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" +#include +#include "velox/connectors/hive/iceberg/IcebergSplit.h" + +namespace facebook::velox::connector::hive::iceberg::test { + +void IcebergTestBase::SetUp() { + HiveConnectorTestBase::SetUp(); +#ifdef VELOX_ENABLE_PARQUET + parquet::registerParquetReaderFactory(); + parquet::registerParquetWriterFactory(); +#endif + Type::registerSerDe(); + + connectorSessionProperties_ = std::make_shared( + std::unordered_map(), true); + + connectorConfig_ = + std::make_shared(std::make_shared( + std::unordered_map())); + + setupMemoryPools(); + + fuzzerOptions_.vectorSize = 100; + fuzzerOptions_.nullRatio = 0.1; + fuzzer_ = std::make_unique(fuzzerOptions_, opPool_.get()); +} + +void IcebergTestBase::TearDown() { + fuzzer_.reset(); + connectorQueryCtx_.reset(); + connectorPool_.reset(); + opPool_.reset(); + root_.reset(); + HiveConnectorTestBase::TearDown(); +} + +void IcebergTestBase::setupMemoryPools() { + root_.reset(); + opPool_.reset(); + connectorPool_.reset(); + connectorQueryCtx_.reset(); + + root_ = memory::memoryManager()->addRootPool( + "IcebergTest", 1L << 30, exec::MemoryReclaimer::create()); + opPool_ = root_->addLeafChild("operator"); + connectorPool_ = + root_->addAggregateChild("connector", exec::MemoryReclaimer::create()); + + connectorQueryCtx_ = std::make_unique( + opPool_.get(), + connectorPool_.get(), + connectorSessionProperties_.get(), + nullptr, + common::PrefixSortConfig(), + nullptr, + nullptr, + "query.IcebergTest", + "task.IcebergTest", + "planNodeId.IcebergTest", + 0, + ""); +} + +std::vector IcebergTestBase::createTestData( + RowTypePtr rowType, + int32_t numBatches, + vector_size_t rowsPerBatch, + double nullRatio) { + std::vector vectors; + vectors.reserve(numBatches); + + fuzzerOptions_.nullRatio = nullRatio; + fuzzerOptions_.allowDictionaryVector = false; + fuzzerOptions_.timestampPrecision = + fuzzer::FuzzerTimestampPrecision::kMilliSeconds; + 
fuzzer_->setOptions(fuzzerOptions_); + + for (auto i = 0; i < numBatches; ++i) { + vectors.push_back(fuzzer_->fuzzRow(rowType, rowsPerBatch, false)); + } + + return vectors; +} + +IcebergInsertTableHandlePtr IcebergTestBase::createIcebergInsertTableHandle( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath) { + std::vector columnHandles; + for (auto i = 0; i < rowType->size(); ++i) { + auto columnName = rowType->nameOf(i); + auto columnType = HiveColumnHandle::ColumnType::kRegular; + columnHandles.push_back(std::make_shared( + columnName, columnType, rowType->childAt(i), rowType->childAt(i))); + } + + auto locationHandle = std::make_shared( + outputDirectoryPath, + outputDirectoryPath, + LocationHandle::TableType::kNew); + + return std::make_shared( + columnHandles, + locationHandle, + fileFormat_, + common::CompressionKind::CompressionKind_ZSTD); +} + +std::shared_ptr IcebergTestBase::createIcebergDataSink( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath, + const std::vector& partitionTransforms) { + auto tableHandle = + createIcebergInsertTableHandle(rowType, outputDirectoryPath); + return std::make_shared( + rowType, + tableHandle, + connectorQueryCtx_.get(), + connector::CommitStrategy::kNoCommit, + connectorConfig_); +} + +std::vector IcebergTestBase::listFiles( + const std::string& dirPath) { + std::vector files; + if (!std::filesystem::exists(dirPath)) { + return files; + } + + for (auto& dirEntry : + std::filesystem::recursive_directory_iterator(dirPath)) { + if (dirEntry.is_regular_file()) { + files.push_back(dirEntry.path().string()); + } + } + return files; +} + +std::vector> +IcebergTestBase::createSplitsForDirectory(const std::string& directory) { + std::vector> splits; + std::unordered_map customSplitInfo; + customSplitInfo["table_format"] = "hive-iceberg"; + + auto files = listFiles(directory); + for (const auto& filePath : files) { + const auto file = filesystems::getFileSystem(filePath, nullptr) + 
->openFileForRead(filePath); + splits.push_back(std::make_shared( + exec::test::kHiveConnectorId, + filePath, + fileFormat_, + 0, + file->size(), + std::unordered_map>{}, + std::nullopt, + customSplitInfo, + nullptr, + /*cacheable=*/true, + std::vector())); + } + + return splits; +} + +} // namespace facebook::velox::connector::hive::iceberg::test diff --git a/velox/connectors/hive/iceberg/tests/IcebergTestBase.h b/velox/connectors/hive/iceberg/tests/IcebergTestBase.h new file mode 100644 index 000000000000..cb5cdcd0f7ea --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergTestBase.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" +#ifdef VELOX_ENABLE_PARQUET +#include "velox/dwio/parquet/RegisterParquetWriter.h" +#include "velox/dwio/parquet/reader/ParquetReader.h" +#endif + +namespace facebook::velox::connector::hive::iceberg::test { + +class IcebergTestBase : public exec::test::HiveConnectorTestBase { + protected: + void SetUp() override; + + void TearDown() override; + + std::vector createTestData( + RowTypePtr rowType, + int32_t numBatches, + vector_size_t rowsPerBatch, + double nullRatio = 0.0); + + std::shared_ptr createIcebergDataSink( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath, + const std::vector& partitionTransforms = {}); + + std::vector> createSplitsForDirectory( + const std::string& directory); + + std::vector listFiles(const std::string& dirPath); + + dwio::common::FileFormat fileFormat_{dwio::common::FileFormat::DWRF}; + + private: + IcebergInsertTableHandlePtr createIcebergInsertTableHandle( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath); + + std::vector listPartitionDirectories( + const std::string& dataPath); + + void setupMemoryPools(); + + std::shared_ptr root_; + std::shared_ptr opPool_; + std::shared_ptr connectorPool_; + std::shared_ptr connectorSessionProperties_; + std::shared_ptr connectorConfig_; + std::unique_ptr connectorQueryCtx_; + VectorFuzzer::Options fuzzerOptions_; + std::unique_ptr fuzzer_; +}; + +} // namespace facebook::velox::connector::hive::iceberg::test diff --git a/velox/connectors/hive/iceberg/tests/Main.cpp b/velox/connectors/hive/iceberg/tests/Main.cpp new file mode 100644 index 000000000000..3c9dd6615055 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/Main.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/process/ThreadDebugInfo.h" + +#include +#include + +// This main is needed for some tests on linux. +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + // Signal handler required for ThreadDebugInfoTest + facebook::velox::process::addDefaultFatalSignalHandler(); + folly::Init init(&argc, &argv, false); + return RUN_ALL_TESTS(); +} diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsPath.h b/velox/connectors/hive/storage_adapters/abfs/AbfsPath.h index 0dbaa0a4eae1..3a4e0d99f097 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsPath.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsPath.h @@ -54,6 +54,11 @@ static constexpr const char* kAzureOAuthAuthType = "OAuth"; static constexpr const char* kAzureSASAuthType = "SAS"; +// For performance, re - use SAS tokens until the expiry is within this number +// of seconds. +static constexpr const char* kAzureSasTokenRenewPeriod = + "fs.azure.sas.token.renew.period.for.streams"; + // Helper class to parse and extract information from a given ABFS path. 
class AbfsPath { public: diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp index 9747e5e04c4b..575f1f2de572 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp @@ -57,13 +57,15 @@ class AbfsReadFile::Impl { uint64_t offset, uint64_t length, void* buffer, - File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { preadInternal(offset, length, static_cast(buffer)); return {static_cast(buffer), length}; } - std::string pread(uint64_t offset, uint64_t length, File::IoStats* stats) - const { + std::string pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext) const { std::string result(length, 0); preadInternal(offset, length, result.data()); return result; @@ -72,7 +74,7 @@ class AbfsReadFile::Impl { uint64_t preadv( uint64_t offset, const std::vector>& buffers, - File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { size_t length = 0; auto size = buffers.size(); for (auto& range : buffers) { @@ -94,14 +96,18 @@ class AbfsReadFile::Impl { uint64_t preadv( folly::Range regions, folly::Range iobufs, - File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { size_t length = 0; VELOX_CHECK_EQ(regions.size(), iobufs.size()); for (size_t i = 0; i < regions.size(); ++i) { const auto& region = regions[i]; auto& output = iobufs[i]; output = folly::IOBuf(folly::IOBuf::CREATE, region.length); - pread(region.offset, region.length, output.writableData(), stats); + pread( + region.offset, + region.length, + output.writableData(), + fileStorageContext); output.append(region.length); length += region.length; } @@ -162,29 +168,29 @@ std::string_view AbfsReadFile::pread( uint64_t offset, uint64_t length, void* buffer, - File::IoStats* stats) const { - return impl_->pread(offset, length, 
buffer, stats); + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, buffer, fileStorageContext); } std::string AbfsReadFile::pread( uint64_t offset, uint64_t length, - File::IoStats* stats) const { - return impl_->pread(offset, length, stats); + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, fileStorageContext); } uint64_t AbfsReadFile::preadv( uint64_t offset, const std::vector>& buffers, - File::IoStats* stats) const { - return impl_->preadv(offset, buffers, stats); + const FileStorageContext& fileStorageContext) const { + return impl_->preadv(offset, buffers, fileStorageContext); } uint64_t AbfsReadFile::preadv( folly::Range regions, folly::Range iobufs, - File::IoStats* stats) const { - return impl_->preadv(regions, iobufs, stats); + const FileStorageContext& fileStorageContext) const { + return impl_->preadv(regions, iobufs, fileStorageContext); } uint64_t AbfsReadFile::size() const { diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h index 942439c06c1e..b682926ad1af 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h @@ -35,22 +35,22 @@ class AbfsReadFile final : public ReadFile { uint64_t offset, uint64_t length, void* buf, - File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; std::string pread( uint64_t offset, uint64_t length, - File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t preadv( folly::Range regions, folly::Range iobufs, - File::IoStats* stats = nullptr) const final; + const FileStorageContext& 
fileStorageContext = {}) const final; uint64_t size() const final; diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.cpp new file mode 100644 index 000000000000..5aefc0983867 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" +#include "velox/common/config/Config.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" + +namespace facebook::velox::filesystems { + +std::vector extractCacheKeyFromConfig( + const config::ConfigBase& config) { + std::vector cacheKeys; + constexpr std::string_view authTypePrefix{kAzureAccountAuthType}; + for (const auto& [key, value] : config.rawConfigs()) { + if (key.find(authTypePrefix) == 0) { + // Extract the accountName after "fs.azure.account.auth.type.". 
+ auto remaining = std::string_view(key).substr(authTypePrefix.size() + 1); + auto dot = remaining.find("."); + VELOX_USER_CHECK_NE( + dot, + std::string_view::npos, + "Invalid Azure account auth type key: {}", + key); + cacheKeys.emplace_back(CacheKey{remaining.substr(0, dot), value}); + } + } + return cacheKeys; +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h index 925c6f91ece9..1a6cf6e0a0e7 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h @@ -26,6 +26,16 @@ constexpr std::string_view kAbfsScheme{"abfs://"}; constexpr std::string_view kAbfssScheme{"abfss://"}; } // namespace +class ConfigBase; + +struct CacheKey { + const std::string accountName; + const std::string authType; + + CacheKey(std::string_view accountName, std::string_view authType) + : accountName(accountName), authType(authType) {} +}; + inline bool isAbfsFile(const std::string_view filename) { return filename.find(kAbfsScheme) == 0 || filename.find(kAbfssScheme) == 0; } @@ -45,4 +55,7 @@ inline std::string throwStorageExceptionWithOperationDetails( VELOX_FAIL(errMsg); } +std::vector extractCacheKeyFromConfig( + const config::ConfigBase& config); + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h b/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h index 291cc8e73a6a..1a1a68f6d87f 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h +++ b/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h @@ -16,6 +16,8 @@ #pragma once +#include "velox/common/config/Config.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" #include "velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h" #include 
"velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h" diff --git a/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt b/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt index e01169bbd964..136db68d1afd 100644 --- a/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt @@ -23,9 +23,11 @@ if(VELOX_ENABLE_ABFS) AbfsFileSystem.cpp AbfsPath.cpp AbfsReadFile.cpp + AbfsUtil.cpp AbfsWriteFile.cpp AzureClientProviderFactories.cpp AzureClientProviderImpl.cpp + DynamicSasTokenClientProvider.cpp ) velox_link_libraries( diff --git a/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.cpp b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.cpp new file mode 100644 index 000000000000..b98ae99d48ed --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h" + +#include + +namespace facebook::velox::filesystems { + +namespace { + +constexpr int64_t kDefaultSasTokenRenewPeriod = 120; // in seconds + +Azure::DateTime getExpiry(const std::string_view& token) { + if (token.empty()) { + return Azure::DateTime::clock::time_point::min(); + } + + static constexpr std::string_view kSignedExpiry{"se="}; + static constexpr int32_t kSignedExpiryLen = 3; + + auto start = token.find(kSignedExpiry); + if (start == std::string::npos) { + return Azure::DateTime::clock::time_point::min(); + } + start += kSignedExpiryLen; + + auto end = token.find("&", start); + auto seValue = (end == std::string::npos) + ? std::string(token.substr(start)) + : std::string(token.substr(start, end - start)); + + seValue = Azure::Core::Url::Decode(seValue); + auto seDate = + Azure::DateTime::Parse(seValue, Azure::DateTime::DateFormat::Rfc3339); + + static constexpr std::string_view kSignedKeyExpiry = "ske="; + static constexpr int32_t kSignedKeyExpiryLen = 4; + + start = token.find(kSignedKeyExpiry); + if (start == std::string::npos) { + return seDate; + } + start += kSignedKeyExpiryLen; + + end = token.find("&", start); + auto skeValue = (end == std::string::npos) + ? 
std::string(token.substr(start)) + : std::string(token.substr(start, end - start)); + + skeValue = Azure::Core::Url::Decode(skeValue); + auto skeDate = + Azure::DateTime::Parse(skeValue, Azure::DateTime::DateFormat::Rfc3339); + + return std::min(skeDate, seDate); +} + +bool isNearExpiry(Azure::DateTime expiration, int64_t minExpirationInSeconds) { + if (expiration == Azure::DateTime::clock::time_point::min()) { + return true; + } + auto remaining = std::chrono::duration_cast( + expiration - Azure::DateTime::clock::now()) + .count(); + return remaining <= minExpirationInSeconds; +} + +class DynamicSasTokenDataLakeFileClient final : public AzureDataLakeFileClient { + public: + DynamicSasTokenDataLakeFileClient( + const std::shared_ptr& abfsPath, + const std::shared_ptr& sasKeyGenerator, + int64_t sasTokenRenewPeriod) + : abfsPath_(abfsPath), + sasKeyGenerator_(sasKeyGenerator), + sasTokenRenewPeriod_(sasTokenRenewPeriod) {} + + void create() override { + getWriteClient()->Create(); + } + + Azure::Storage::Files::DataLake::Models::PathProperties getProperties() + override { + return getReadClient()->GetProperties().Value; + } + + void append(const uint8_t* buffer, size_t size, uint64_t offset) override { + auto bodyStream = Azure::Core::IO::MemoryBodyStream(buffer, size); + getWriteClient()->Append(bodyStream, offset); + } + + void flush(uint64_t position) override { + getWriteClient()->Flush(position); + } + + void close() override {} + + std::string getUrl() override { + return getWriteClient()->GetUrl(); + } + + private: + std::shared_ptr abfsPath_; + std::shared_ptr sasKeyGenerator_; + int64_t sasTokenRenewPeriod_; + + std::unique_ptr writeClient_{nullptr}; + Azure::DateTime writeSasExpiration_{ + Azure::DateTime::clock::time_point::min()}; + + std::unique_ptr readClient_{nullptr}; + Azure::DateTime readSasExpiration_{Azure::DateTime::clock::time_point::min()}; + + DataLakeFileClient* getWriteClient() { + if (writeClient_ == nullptr || + 
isNearExpiry(writeSasExpiration_, sasTokenRenewPeriod_)) { + const auto& sas = sasKeyGenerator_->getSasToken( + abfsPath_->fileSystem(), abfsPath_->filePath(), kAbfsWriteOperation); + writeSasExpiration_ = getExpiry(sas); + writeClient_ = std::make_unique( + fmt::format("{}?{}", abfsPath_->getUrl(false), sas)); + } + return writeClient_.get(); + } + + DataLakeFileClient* getReadClient() { + if (readClient_ == nullptr || + isNearExpiry(readSasExpiration_, sasTokenRenewPeriod_)) { + const auto& sas = sasKeyGenerator_->getSasToken( + abfsPath_->fileSystem(), abfsPath_->filePath(), kAbfsReadOperation); + readSasExpiration_ = getExpiry(sas); + readClient_ = std::make_unique( + fmt::format("{}?{}", abfsPath_->getUrl(false), sas)); + } + return readClient_.get(); + } +}; + +class DynamicSasTokenBlobClient : public AzureBlobClient { + public: + DynamicSasTokenBlobClient( + const std::shared_ptr& abfsPath, + const std::shared_ptr& sasTokenProvider, + int64_t sasTokenRenewPeriod) + : abfsPath_(abfsPath), + sasTokenProvider_(sasTokenProvider), + sasTokenRenewPeriod_(sasTokenRenewPeriod) {} + + Azure::Response getProperties() + override { + return getBlobClient()->GetProperties(); + } + + Azure::Response download( + const Azure::Storage::Blobs::DownloadBlobOptions& options) override { + return getBlobClient()->Download(options); + } + + std::string getUrl() override { + return getBlobClient()->GetUrl(); + } + + private: + std::shared_ptr abfsPath_; + std::shared_ptr sasTokenProvider_; + int64_t sasTokenRenewPeriod_; + + std::unique_ptr blobClient_{nullptr}; + Azure::DateTime sasExpiration_{Azure::DateTime::clock::time_point::min()}; + + BlobClient* getBlobClient() { + if (blobClient_ == nullptr || + isNearExpiry(sasExpiration_, sasTokenRenewPeriod_)) { + const auto& sas = sasTokenProvider_->getSasToken( + abfsPath_->fileSystem(), abfsPath_->filePath(), kAbfsReadOperation); + sasExpiration_ = getExpiry(sas); + blobClient_ = std::make_unique( + fmt::format("{}?{}", 
abfsPath_->getUrl(true), sas)); + } + return blobClient_.get(); + } +}; + +} // namespace + +DynamicSasTokenClientProvider::DynamicSasTokenClientProvider( + const std::shared_ptr& sasTokenProvider) + : AzureClientProvider(), sasTokenProvider_(sasTokenProvider) {} + +void DynamicSasTokenClientProvider::init(const config::ConfigBase& config) { + sasTokenRenewPeriod_ = config.get( + kAzureSasTokenRenewPeriod, kDefaultSasTokenRenewPeriod); +} + +std::unique_ptr +DynamicSasTokenClientProvider::getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(config); + return std::make_unique( + abfsPath, sasTokenProvider_, sasTokenRenewPeriod_); +} + +std::unique_ptr +DynamicSasTokenClientProvider::getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(config); + return std::make_unique( + abfsPath, sasTokenProvider_, sasTokenRenewPeriod_); +} +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h new file mode 100644 index 000000000000..ab1d53f0045b --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h" + +namespace facebook::velox::filesystems { + +/// SAS permissions reference: +/// https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#permissions-for-a-directory-container-or-blob +/// +/// ReadClient uses "read" permission for Download and GetProperties. +/// WriteClient uses "read" permission for GetProperties, and "write" permission +/// for other operations. +static const std::string kAbfsReadOperation{"read"}; +static const std::string kAbfsWriteOperation{"write"}; + +/// Interface for providing SAS tokens for ABFS file system operations. +/// Adapted from the Hadoop Azure implementation: +/// org.apache.hadoop.fs.azurebfs.extensions.SASTokenProvider +class SasTokenProvider { + public: + virtual ~SasTokenProvider() = default; + + virtual std::string getSasToken( + const std::string& fileSystem, + const std::string& path, + const std::string& operation) = 0; +}; + +/// Client provider that dynamically refreshes SAS tokens based on the +/// expiration time of the token. A SasTokenProvider for retrieving SAS tokens +/// must be provided to this class. 
Example for generating the SAS token can be +/// found in: +/// https://github.com/Azure/azure-sdk-for-cpp/blob/3d917e7c178f0a49b189395a907180084857cc70/sdk/storage/azure-storage-blobs/samples/blob_sas.cpp +class DynamicSasTokenClientProvider : public AzureClientProvider { + public: + explicit DynamicSasTokenClientProvider( + const std::shared_ptr& sasTokenProvider); + + std::unique_ptr getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + std::unique_ptr getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + private: + void init(const config::ConfigBase& config); + + std::shared_ptr sasTokenProvider_; + int64_t sasTokenRenewPeriod_; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp index dcb46aa05298..a87153928633 100644 --- a/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp @@ -68,40 +68,28 @@ void registerAbfsFileSystem() { void registerAzureClientProvider(const config::ConfigBase& config) { #ifdef VELOX_ENABLE_ABFS - for (const auto& [key, value] : config.rawConfigs()) { - constexpr std::string_view authTypePrefix{kAzureAccountAuthType}; - if (key.find(authTypePrefix) == 0) { - std::string_view skey = key; - // Extract the accountName after "fs.azure.account.auth.type.". 
- auto remaining = skey.substr(authTypePrefix.size() + 1); - auto dot = remaining.find("."); - VELOX_USER_CHECK_NE( - dot, - std::string_view::npos, - "Invalid Azure account auth type key: {}", - key); - auto accountName = std::string(remaining.substr(0, dot)); - if (value == kAzureSharedKeyAuthType) { - AzureClientProviderFactories::registerFactory( - accountName, [](const std::string&) { - return std::make_unique(); - }); - } else if (value == kAzureOAuthAuthType) { - AzureClientProviderFactories::registerFactory( - accountName, [](const std::string&) { - return std::make_unique(); - }); - } else if (value == kAzureSASAuthType) { - AzureClientProviderFactories::registerFactory( - accountName, [](const std::string&) { - return std::make_unique(); - }); - } else { - VELOX_USER_FAIL( - "Unsupported auth type {}, supported auth types are SharedKey, OAuth and SAS.", - value); - } + for (const auto& [accountName, authType] : + extractCacheKeyFromConfig(config)) { + if (authType == kAzureSharedKeyAuthType) { + AzureClientProviderFactories::registerFactory( + accountName, [](const std::string&) { + return std::make_unique(); + }); + } else if (authType == kAzureOAuthAuthType) { + AzureClientProviderFactories::registerFactory( + accountName, [](const std::string&) { + return std::make_unique(); + }); + } else if (authType == kAzureSASAuthType) { + AzureClientProviderFactories::registerFactory( + accountName, [](const std::string&) { + return std::make_unique(); + }); + } else { + VELOX_USER_FAIL( + "Unsupported auth type {}, supported auth types are SharedKey, OAuth and SAS.", + authType); } } #endif diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt index 8deb76801b02..c81471db9f68 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt @@ -18,6 +18,7 @@ add_executable( AbfsUtilTest.cpp 
AzureClientProvidersTest.cpp AzureClientProviderFactoriesTest.cpp + DynamicSasTokenClientProviderTest.cpp AzuriteServer.cpp MockDataLakeFileClient.cpp ) diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/DynamicSasTokenClientProviderTest.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/DynamicSasTokenClientProviderTest.cpp new file mode 100644 index 000000000000..21f81355a2ef --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/tests/DynamicSasTokenClientProviderTest.cpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" +#include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" + +#include "gtest/gtest.h" + +#include +#include +#include + +using namespace facebook::velox::filesystems; +using namespace facebook::velox; + +namespace { + +class MyDynamicAbfsSasTokenProvider : public SasTokenProvider { + public: + MyDynamicAbfsSasTokenProvider(int64_t expiration) + : expirationSeconds_(expiration) {} + + std::string getSasToken( + const std::string& fileSystem, + const std::string& path, + const std::string& operation) override { + const auto lastSlash = path.find_last_of("/"); + const auto containerName = path.substr(0, lastSlash); + const auto blobName = path.substr(lastSlash + 1); + + Azure::Storage::Sas::BlobSasBuilder sasBuilder; + sasBuilder.ExpiresOn = Azure::DateTime::clock::now() + + std::chrono::seconds(expirationSeconds_); + sasBuilder.BlobContainerName = containerName; + sasBuilder.BlobName = blobName; + sasBuilder.Resource = Azure::Storage::Sas::BlobSasResource::Blob; + sasBuilder.SetPermissions( + Azure::Storage::Sas::BlobSasPermissions::Read & + Azure::Storage::Sas::BlobSasPermissions::Write); + + std::string sasToken = + sasBuilder.GenerateSasToken(Azure::Storage::StorageSharedKeyCredential( + "test", + "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==")); + + // Remove the leading '?' from the SAS token. 
+ if (sasToken[0] == '?') { + sasToken = sasToken.substr(1); + } + + return sasToken; + } + + private: + int64_t expirationSeconds_; +}; + +} // namespace + +TEST(DynamicSasTokenClientProviderTest, dynamicSasToken) { + { + const std::string account = "account1"; + const config::ConfigBase config( + {{"fs.azure.account.auth.type.account1.dfs.core.windows.net", "SAS"}, + {"fs.azure.sas.token.renew.period.for.streams", "1"}}, + false); + registerAzureClientProviderFactory(account, [](const std::string&) { + auto sasTokenProvider = + std::make_shared(3); + return std::make_unique(sasTokenProvider); + }); + + auto abfsPath = std::make_shared( + fmt::format("abfs://abc@{}.dfs.core.windows.net/file", account)); + auto readClient = + AzureClientProviderFactories::getReadFileClient(abfsPath, config); + auto writeClient = + AzureClientProviderFactories::getWriteFileClient(abfsPath, config); + + auto readUrl = readClient->getUrl(); + auto writeUrl = writeClient->getUrl(); + + // Let the current time pass 3 seconds to ensure the SAS token is expired. + std::this_thread::sleep_for(std::chrono::seconds(3)); // NOLINT + + auto newReadUrl = readClient->getUrl(); + ASSERT_NE(readUrl, newReadUrl); + // The SAS token should be reused. + ASSERT_EQ(newReadUrl, readClient->getUrl()); + + auto newWriteUrl = writeClient->getUrl(); + ASSERT_NE(writeUrl, newWriteUrl); + // The SAS token should be reused. + ASSERT_EQ(newWriteUrl, writeClient->getUrl()); + } + + { + // SAS token expired by setting the renewal period to 120 seconds. 
+ const std::string account = "account2"; + const config::ConfigBase config( + {{"fs.azure.account.auth.type.account2.dfs.core.windows.net", "SAS"}, + {"fs.azure.sas.token.renew.period.for.streams", "120"}}, + false); + registerAzureClientProviderFactory(account, [](const std::string&) { + auto sasTokenProvider = + std::make_shared(60); + return std::make_unique(sasTokenProvider); + }); + + auto abfsPath = std::make_shared( + fmt::format("abfs://abc@{}.dfs.core.windows.net/file", account)); + auto readClient = + AzureClientProviderFactories::getReadFileClient(abfsPath, config); + auto writeClient = + AzureClientProviderFactories::getWriteFileClient(abfsPath, config); + + auto readUrl = readClient->getUrl(); + auto writeUrl = writeClient->getUrl(); + + // Let the current time pass 3 seconds to ensure the timestamp in the SAS + // token is updated. + std::this_thread::sleep_for(std::chrono::seconds(3)); // NOLINT + + // Sas token should be renewed because the time left is less than the + // renewal period. 
+ ASSERT_NE(readUrl, readClient->getUrl()); + ASSERT_NE(writeUrl, writeClient->getUrl()); + } +} diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp index 072a4f7a37f2..e8cf14dbb2f4 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp @@ -59,7 +59,7 @@ class GcsReadFile::Impl { uint64_t length, void* buffer, std::atomic& bytesRead, - filesystems::File::IoStats* stats = nullptr) const { + const FileStorageContext& fileStorageContext) const { preadInternal(offset, length, static_cast(buffer), bytesRead); return {static_cast(buffer), length}; } @@ -68,7 +68,7 @@ class GcsReadFile::Impl { uint64_t offset, uint64_t length, std::atomic& bytesRead, - filesystems::File::IoStats* stats = nullptr) const { + const FileStorageContext& fileStorageContext) const { std::string result(length, 0); char* position = result.data(); preadInternal(offset, length, position, bytesRead); @@ -79,7 +79,7 @@ class GcsReadFile::Impl { uint64_t offset, const std::vector>& buffers, std::atomic& bytesRead, - filesystems::File::IoStats* stats = nullptr) const { + const FileStorageContext& fileStorageContext) const { // 'buffers' contains Ranges(data, size) with some gaps (data = nullptr) in // between. This call must populate the ranges (except gap ranges) // sequentially starting from 'offset'. 
If a range pointer is nullptr, the @@ -158,21 +158,21 @@ std::string_view GcsReadFile::pread( uint64_t offset, uint64_t length, void* buffer, - filesystems::File::IoStats* stats) const { - return impl_->pread(offset, length, buffer, bytesRead_, stats); + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, buffer, bytesRead_, fileStorageContext); } std::string GcsReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { - return impl_->pread(offset, length, bytesRead_, stats); + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, bytesRead_, fileStorageContext); } uint64_t GcsReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { - return impl_->preadv(offset, buffers, bytesRead_, stats); + const FileStorageContext& fileStorageContext) const { + return impl_->preadv(offset, buffers, bytesRead_, fileStorageContext); } uint64_t GcsReadFile::size() const { diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h index a3d328996ece..6e79ee34afde 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h +++ b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h @@ -38,17 +38,17 @@ class GcsReadFile : public ReadFile { uint64_t offset, uint64_t length, void* buffer, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; uint64_t size() const override; diff --git 
a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp index aedd1fe44d33..46522ab2a6f5 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp @@ -56,11 +56,29 @@ class HdfsFileSystem::Impl { } ~Impl() { - LOG(INFO) << "Disconnecting HDFS file system"; - int disconnectResult = driver_->Disconnect(hdfsClient_); - if (disconnectResult != 0) { - LOG(WARNING) << "hdfs disconnect failure in HdfsReadFile close: " - << errno; + if (!closed_) { + LOG(WARNING) + << "The HdfsFileSystem instance is not closed upon destruction. You must explicitly call the close() API before JVM termination to ensure proper disconnection."; + } + } + + // The HdfsFileSystem::Disconnect operation requires the JVM method + // definitions to be loaded within an active JVM process. + // Therefore, it must be invoked before the JVM shuts down. + + // To address this, we’ve introduced a new close() API that performs the + // disconnect operation. Third-party applications can call this close() method + // prior to JVM termination to ensure proper cleanup. 
+ void close() { + if (!closed_) { + LOG(WARNING) << "Disconnecting HDFS file system"; + int disconnectResult = driver_->Disconnect(hdfsClient_); + if (disconnectResult != 0) { + LOG(WARNING) << "hdfs disconnect failure in HdfsReadFile close: " + << errno; + } + + closed_ = true; } } @@ -75,6 +93,7 @@ class HdfsFileSystem::Impl { private: hdfsFS hdfsClient_; filesystems::arrow::io::internal::LibHdfsShim* driver_; + bool closed_ = false; }; HdfsFileSystem::HdfsFileSystem( @@ -109,6 +128,10 @@ std::unique_ptr HdfsFileSystem::openFileForWrite( impl_->hdfsShim(), impl_->hdfsClient(), path); } +void HdfsFileSystem::close() { + impl_->close(); +} + bool HdfsFileSystem::isHdfsFile(const std::string_view filePath) { return (filePath.find(kScheme) == 0) || (filePath.find(kViewfsScheme) == 0); } diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h index cebe40aa890d..9720bb13034b 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h @@ -61,6 +61,8 @@ class HdfsFileSystem : public FileSystem { std::string_view path, const FileOptions& options = {}) override; + void close(); + // Deletes the hdfs files. 
void remove(std::string_view path) override; diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp index affc1dfd2ede..1d320cda44b6 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp @@ -103,12 +103,19 @@ class HdfsReadFile::Impl { } } - std::string_view pread(uint64_t offset, uint64_t length, void* buf) const { + std::string_view pread( + uint64_t offset, + uint64_t length, + void* buf, + const FileStorageContext& fileStorageContext) const { preadInternal(offset, length, static_cast(buf)); return {static_cast(buf), length}; } - std::string pread(uint64_t offset, uint64_t length) const { + std::string pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext) const { std::string result(length, 0); char* pos = result.data(); preadInternal(offset, length, pos); @@ -163,15 +170,15 @@ std::string_view HdfsReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { - return pImpl->pread(offset, length, buf); + const FileStorageContext& fileStorageContext) const { + return pImpl->pread(offset, length, buf, fileStorageContext); } std::string HdfsReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { - return pImpl->pread(offset, length); + const FileStorageContext& fileStorageContext) const { + return pImpl->pread(offset, length, fileStorageContext); } uint64_t HdfsReadFile::size() const { diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h index ddd35e511a71..a59b178909c6 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h @@ -38,12 +38,12 @@ class HdfsReadFile final : public ReadFile { uint64_t offset, uint64_t 
length, void* buf, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t size() const final; diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp b/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp index be668a3133e1..26d43ccb9100 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp @@ -54,12 +54,14 @@ HdfsWriteFile::~HdfsWriteFile() { void HdfsWriteFile::close() { int success = driver_->CloseFile(hdfsClient_, hdfsFile_); + common::testutil::TestValue::adjust( + "facebook::velox::connectors::hive::HdfsWriteFile::close", &success); + hdfsFile_ = nullptr; VELOX_CHECK_EQ( success, 0, "Failed to close hdfs file: {}", driver_->GetLastExceptionRootCause()); - hdfsFile_ = nullptr; } void HdfsWriteFile::flush() { diff --git a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp index 1f23179f0a72..bb4f208c4731 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp @@ -28,44 +28,48 @@ namespace facebook::velox::filesystems { #ifdef VELOX_ENABLE_HDFS std::mutex mtx; +folly::ConcurrentHashMap> + registeredFilesystems; + std::function(std::shared_ptr, std::string_view)> hdfsFileSystemGenerator() { - static auto filesystemGenerator = [](std::shared_ptr - properties, - std::string_view filePath) { - static folly::ConcurrentHashMap> - filesystems; - static folly:: - ConcurrentHashMap> - hdfsInitiationFlags; - HdfsServiceEndpoint endpoint = - HdfsFileSystem::getServiceEndpoint(filePath, properties.get()); - 
std::string hdfsIdentity = endpoint.identity(); - if (filesystems.find(hdfsIdentity) != filesystems.end()) { - return filesystems[hdfsIdentity]; - } - std::unique_lock lk(mtx, std::defer_lock); - /// If the init flag for a given hdfs identity is not found, - /// create one for init use. It's a singleton. - if (hdfsInitiationFlags.find(hdfsIdentity) == hdfsInitiationFlags.end()) { - lk.lock(); - if (hdfsInitiationFlags.find(hdfsIdentity) == hdfsInitiationFlags.end()) { - std::shared_ptr initiationFlagPtr = - std::make_shared(); - hdfsInitiationFlags.insert(hdfsIdentity, initiationFlagPtr); - } - lk.unlock(); - } - folly::call_once( - *hdfsInitiationFlags[hdfsIdentity].get(), - [&properties, endpoint, hdfsIdentity]() { - auto filesystem = - std::make_shared(properties, endpoint); - filesystems.insert(hdfsIdentity, filesystem); - }); - return filesystems[hdfsIdentity]; - }; + static auto filesystemGenerator = + [](std::shared_ptr properties, + std::string_view filePath) { + static folly:: + ConcurrentHashMap> + hdfsInitiationFlags; + HdfsServiceEndpoint endpoint = + HdfsFileSystem::getServiceEndpoint(filePath, properties.get()); + std::string hdfsIdentity = endpoint.identity(); + if (registeredFilesystems.find(hdfsIdentity) != + registeredFilesystems.end()) { + return registeredFilesystems[hdfsIdentity]; + } + std::unique_lock lk(mtx, std::defer_lock); + /// If the init flag for a given hdfs identity is not found, + /// create one for init use. It's a singleton. 
+ if (hdfsInitiationFlags.find(hdfsIdentity) == + hdfsInitiationFlags.end()) { + lk.lock(); + if (hdfsInitiationFlags.find(hdfsIdentity) == + hdfsInitiationFlags.end()) { + std::shared_ptr initiationFlagPtr = + std::make_shared(); + hdfsInitiationFlags.insert(hdfsIdentity, initiationFlagPtr); + } + lk.unlock(); + } + folly::call_once( + *hdfsInitiationFlags[hdfsIdentity].get(), + [&properties, endpoint, hdfsIdentity]() { + auto filesystem = + std::make_shared(properties, endpoint); + registeredFilesystems.insert(hdfsIdentity, filesystem); + }); + return registeredFilesystems[hdfsIdentity]; + }; return filesystemGenerator; } diff --git a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h index 6f6f0c032bd7..18eef4aca176 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h @@ -16,8 +16,15 @@ #pragma once +#include "folly/concurrency/ConcurrentHashMap.h" + namespace facebook::velox::filesystems { +class HdfsFileSystem; + +extern folly::ConcurrentHashMap> + registeredFilesystems; + // Register the HDFS. void registerHdfsFileSystem(); diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt index a0ad9d67e99c..87dd4681433f 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt @@ -49,7 +49,7 @@ target_compile_options(velox_hdfs_insert_test PRIVATE -Wno-deprecated-declaratio # velox_hdfs_insert_test and velox_hdfs_file_test two tests can't run in # parallel due to the port conflict in Hadoop NameNode and DataNode. The # namenode port conflict can be resolved using the -nnport configuration in -# hadoop-mapreduce-client-jobclient-3.3.0-tests.jar. 
However the data node port +# hadoop-mapreduce-client-jobclient-3.3.6-tests.jar. However the data node port # cannot be configured. Therefore, we need to make sure that # velox_hdfs_file_test runs only after velox_hdfs_insert_test has finished. set_tests_properties(velox_hdfs_insert_test PROPERTIES DEPENDS velox_hdfs_file_test) diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp index e5c0883284b9..2fce65c91c55 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp @@ -21,6 +21,7 @@ #include "gtest/gtest.h" #include "velox/common/base/Exceptions.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h" #include "velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h" @@ -72,6 +73,11 @@ class HdfsFileSystemTest : public testing::Test { } static void TearDownTestSuite() { + for (const auto& [_, filesystem] : + facebook::velox::filesystems::registeredFilesystems) { + filesystem->close(); + } + miniCluster->stop(); } @@ -520,3 +526,34 @@ TEST_F(HdfsFileSystemTest, readFailures) { std::string(miniCluster->nameNodePort())); verifyFailures(driver, hdfs); } + +DEBUG_ONLY_TEST_F(HdfsFileSystemTest, writeFilePreventsDoubleClose) { + common::testutil::TestValue::enable(); + + int closeCallCount = 0; + + SCOPED_TESTVALUE_SET( + "facebook::velox::connectors::hive::HdfsWriteFile::close", + std::function([&closeCallCount](int* success) { + ++closeCallCount; + if (closeCallCount == 1) { + *success = -1; + } + })); + + auto writeFile = openFileForWrite("/test_double_close.txt"); + + writeFile->append("test data"); + writeFile->flush(); + + 
VELOX_ASSERT_THROW(writeFile->close(), "Failed to close hdfs file:"); + + EXPECT_EQ(closeCallCount, 1); + + // Destructor should not call close() again because hdfsFile_ is nullptr + // The closeCallCount should remain 1. + writeFile.reset(); + EXPECT_EQ(closeCallCount, 1); + + common::testutil::TestValue::disable(); +} diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp index 9ec9a1254154..ed2287a7c42d 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp @@ -17,6 +17,7 @@ #include "gtest/gtest.h" #include "velox/common/memory/Memory.h" +#include "velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h" #include "velox/connectors/hive/storage_adapters/test_common/InsertTest.h" @@ -47,6 +48,10 @@ class HdfsInsertTest : public testing::Test, public InsertTest { } void TearDown() override { + for (const auto& [_, filesystem] : + facebook::velox::filesystems::registeredFilesystems) { + filesystem->close(); + } InsertTest::TearDown(); miniCluster->stop(); } diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h index c54ae9589b3e..da07cb341a85 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h @@ -26,7 +26,7 @@ static const std::string kMiniClusterExecutableName{"hadoop"}; static const std::string kHadoopSearchPath{":/usr/local/hadoop/bin"}; static const std::string kJarCommand{"jar"}; static const std::string kMiniclusterJar{ - "/share/hadoop/mapreduce/hadoop-mapreduce-client-jobclient-3.3.0-tests.jar"}; + 
"/share/hadoop/mapreduce/hadoop-mapreduce-client-jobclient-3.3.6-tests.jar"}; static const std::string kMiniclusterCommand{"minicluster"}; static const std::string kNoMapReduceOption{"-nomr"}; static const std::string kFormatNameNodeOption{"-format"}; diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp index 06d180b19f7a..38d66318f3e2 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp @@ -79,13 +79,15 @@ class S3ReadFile ::Impl { uint64_t offset, uint64_t length, void* buffer, - File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { preadInternal(offset, length, static_cast(buffer)); return {static_cast(buffer), length}; } - std::string pread(uint64_t offset, uint64_t length, File::IoStats* stats) - const { + std::string pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext) const { std::string result(length, 0); char* position = result.data(); preadInternal(offset, length, position); @@ -95,7 +97,7 @@ class S3ReadFile ::Impl { uint64_t preadv( uint64_t offset, const std::vector>& buffers, - File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { // 'buffers' contains Ranges(data, size) with some gaps (data = nullptr) in // between. This call must populate the ranges (except gap ranges) // sequentially starting from 'offset'. 
AWS S3 GetObject does not support @@ -183,22 +185,22 @@ std::string_view S3ReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { - return impl_->pread(offset, length, buf, stats); + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, buf, fileStorageContext); } std::string S3ReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { - return impl_->pread(offset, length, stats); + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, fileStorageContext); } uint64_t S3ReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { - return impl_->preadv(offset, buffers, stats); + const FileStorageContext& fileStorageContext) const { + return impl_->preadv(offset, buffers, fileStorageContext); } uint64_t S3ReadFile::size() const { diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h index 0b08ed0ec188..de7eb63f5ada 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h +++ b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h @@ -35,17 +35,17 @@ class S3ReadFile : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t size() const final; diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3Util.h b/velox/connectors/hive/storage_adapters/s3fs/S3Util.h index 
ab2e25790d04..5b6937c76215 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3Util.h +++ b/velox/connectors/hive/storage_adapters/s3fs/S3Util.h @@ -204,7 +204,7 @@ std::optional parseAWSStandardRegionName( class S3ProxyConfigurationBuilder { public: S3ProxyConfigurationBuilder(const std::string& s3Endpoint) - : s3Endpoint_(s3Endpoint){}; + : s3Endpoint_(s3Endpoint) {} S3ProxyConfigurationBuilder& useSsl(const bool& useSsl) { useSsl_ = useSsl; diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp index aba2ac902363..907acf927838 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include #include @@ -24,8 +25,9 @@ #include -namespace facebook::velox::filesystems { +namespace facebook::velox::filesystems::test { namespace { + class S3TestReporter : public BaseStatsReporter { public: mutable std::mutex m; @@ -40,6 +42,7 @@ class S3TestReporter : public BaseStatsReporter { statTypeMap.clear(); histogramPercentilesMap.clear(); } + void registerMetricExportType(const char* key, StatType statType) const override { statTypeMap[key] = statType; @@ -68,18 +71,6 @@ class S3TestReporter : public BaseStatsReporter { histogramPercentilesMap[key.str()] = pcts; } - void registerQuantileMetricExportType( - const char* /* key */, - const std::vector& /* statTypes */, - const std::vector& /* pcts */, - const std::vector& /* slidingWindowsSeconds */) const override {} - - void registerQuantileMetricExportType( - folly::StringPiece /* key */, - const std::vector& /* statTypes */, - const std::vector& /* pcts */, - const std::vector& /* slidingWindowsSeconds */) const override {} - void addMetricValue(const 
std::string& key, const size_t value) const override { std::lock_guard l(m); @@ -110,42 +101,6 @@ class S3TestReporter : public BaseStatsReporter { counterMap[key.str()] = std::max(counterMap[key.str()], value); } - void addQuantileMetricValue(const std::string& /* key */, size_t /* value */) - const override {} - - void addQuantileMetricValue(const char* /* key */, size_t /* value */) - const override {} - - void addQuantileMetricValue(folly::StringPiece /* key */, size_t /* value */) - const override {} - - void registerDynamicQuantileMetricExportType( - const char* /* keyPattern */, - const std::vector& /* statTypes */, - const std::vector& /* pcts */, - const std::vector& /* slidingWindowsSeconds */) const override {} - - void registerDynamicQuantileMetricExportType( - folly::StringPiece /* keyPattern */, - const std::vector& /* statTypes */, - const std::vector& /* pcts */, - const std::vector& /* slidingWindowsSeconds */) const override {} - - void addDynamicQuantileMetricValue( - const std::string& /* key */, - folly::Range /* subkeys */, - size_t /* value */) const override {} - - void addDynamicQuantileMetricValue( - const char* /* key */, - folly::Range /* subkeys */, - size_t /* value */) const override {} - - void addDynamicQuantileMetricValue( - folly::StringPiece /* key */, - folly::Range /* subkeys */, - size_t /* value */) const override {} - std::string fetchMetrics() override { std::stringstream ss; ss << "["; @@ -217,7 +172,7 @@ TEST_F(S3FileSystemMetricsTest, metrics) { EXPECT_EQ(1, s3Reporter->counterMap[std::string{kMetricS3GetObjectCalls}]); } -} // namespace facebook::velox::filesystems +} // namespace facebook::velox::filesystems::test int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/velox/connectors/hive/tests/HiveConnectorTest.cpp b/velox/connectors/hive/tests/HiveConnectorTest.cpp index 6fdc8d130c98..c5d69f26d1d1 100644 --- a/velox/connectors/hive/tests/HiveConnectorTest.cpp +++ 
b/velox/connectors/hive/tests/HiveConnectorTest.cpp @@ -21,6 +21,7 @@ #include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/HiveConnectorUtil.h" #include "velox/connectors/hive/HiveDataSource.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" namespace facebook::velox::connector::hive { @@ -610,6 +611,41 @@ TEST_F(HiveConnectorTest, extractFiltersFromRemainingFilter) { ASSERT_TRUE(remaining); ASSERT_EQ( remaining->toString(), "not(lt(ROW[\"c2\"],cast(0 as DECIMAL(20, 0))))"); + + // parseExpr gives AND/OR with 2 arguments. We need to construct the node + // manually to have more than 2. + expr = std::make_shared( + BOOLEAN(), + expression::kAnd, + parseExpr("c0 > 0", rowType), + parseExpr("c1 > 0", rowType), + parseExpr("c2 > 0::decimal(20, 0)", rowType)); + filters.clear(); + remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, false, filters, sampleRate); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 3); + ASSERT_TRUE(filters.contains(Subfield("c0"))); + ASSERT_TRUE(filters.contains(Subfield("c1"))); + ASSERT_TRUE(filters.contains(Subfield("c2"))); + ASSERT_FALSE(remaining); + + expr = std::make_shared( + BOOLEAN(), + expression::kAnd, + parseExpr("c0 % 2 = 0", rowType), + parseExpr("c1 % 3 = 0", rowType), + parseExpr("c2 > 0::decimal(20, 0)", rowType)); + filters.clear(); + remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, false, filters, sampleRate); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 1); + ASSERT_TRUE(filters.contains(Subfield("c2"))); + ASSERT_TRUE(remaining); + ASSERT_EQ( + remaining->toString(), + "and(eq(mod(ROW[\"c0\"],2),0),eq(mod(ROW[\"c1\"],3),0))"); } TEST_F(HiveConnectorTest, prestoTableSampling) { diff --git a/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp b/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp index 7a0653ad22a6..2c6dd76f4549 100644 --- a/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp 
+++ b/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp @@ -352,6 +352,7 @@ TEST_F(HiveConnectorUtilTest, configureSstRowReaderOptions) { /*hiveSplit=*/hiveSplit, /*hiveConfig=*/nullptr, /*sessionProperties=*/nullptr, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_EQ(rowReaderOpts.serdeParameters(), hiveSplit->serdeParameters); @@ -377,6 +378,7 @@ TEST_F(HiveConnectorUtilTest, configureRowReaderOptionsFromConfig) { /*hiveSplit=*/hiveSplit, /*hiveConfig=*/hiveConfig, /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_FALSE(rowReaderOpts.preserveFlatMapsInMemory()); @@ -402,6 +404,7 @@ TEST_F(HiveConnectorUtilTest, configureRowReaderOptionsFromConfig) { /*hiveSplit=*/hiveSplit, /*hiveConfig=*/hiveConfig, /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_TRUE(rowReaderOpts.preserveFlatMapsInMemory()); @@ -428,6 +431,7 @@ TEST_F(HiveConnectorUtilTest, configureRowReaderOptionsFromConfig) { /*hiveSplit=*/hiveSplit, /*hiveConfig=*/hiveConfig, /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_TRUE(rowReaderOpts.preserveFlatMapsInMemory()); @@ -455,6 +459,7 @@ TEST_F(HiveConnectorUtilTest, configureRowReaderOptionsFromConfig) { /*hiveSplit=*/hiveSplit, /*hiveConfig=*/hiveConfig, /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_TRUE(rowReaderOpts.preserveFlatMapsInMemory()); diff --git a/velox/connectors/hive/tests/HiveDataSinkTest.cpp b/velox/connectors/hive/tests/HiveDataSinkTest.cpp index 4e3f0a19f4ed..cee5b9de971f 100644 --- a/velox/connectors/hive/tests/HiveDataSinkTest.cpp +++ b/velox/connectors/hive/tests/HiveDataSinkTest.cpp @@ -598,6 +598,67 @@ TEST_F(HiveDataSinkTest, basicBucket) { verifyWrittenData(outputDirectory->getPath(), numBuckets); } +TEST_F(HiveDataSinkTest, decimalPartition) { + 
const auto outputDirectory = TempDirectoryPath::create(); + + connectorSessionProperties_->set( + HiveConfig::kSortWriterFinishTimeSliceLimitMsSession, "1"); + const auto rowType = + ROW({"c0", "c1", "c2"}, {BIGINT(), DECIMAL(14, 3), DECIMAL(20, 4)}); + auto dataSink = createDataSink( + rowType, + outputDirectory->getPath(), + dwio::common::FileFormat::DWRF, + {"c2"}); + auto stats = dataSink->stats(); + ASSERT_TRUE(stats.empty()) << stats.toString(); + + const auto vector = makeRowVector( + {makeNullableFlatVector({1, 2, std::nullopt, 345}), + makeNullableFlatVector( + {1, 2, std::nullopt, 345}, DECIMAL(14, 3)), + makeFlatVector({1, 340, 234567, -345}, DECIMAL(20, 4))}); + + dataSink->appendData(vector); + while (!dataSink->finish()) { + } + const auto partitions = dataSink->close(); + stats = dataSink->stats(); + ASSERT_FALSE(stats.empty()); + ASSERT_EQ(partitions.size(), vector->size()); + + createDuckDbTable({vector}); + + const auto rootPath = outputDirectory->getPath(); + std::vector> splits; + std::unordered_map> partitionKeys; + auto partitionPath = [&](std::string value) { + partitionKeys["c2"] = value; + auto path = listFiles(rootPath + "/c2=" + value)[0]; + splits.push_back(makeHiveConnectorSplits( + path, 1, dwio::common::FileFormat::DWRF, partitionKeys) + .back()); + }; + partitionPath("0.0001"); + partitionPath("0.0340"); + partitionPath("23.4567"); + partitionPath("-0.0345"); + + ColumnHandleMap assignments = { + {"c0", regularColumn("c0", BIGINT())}, + {"c1", regularColumn("c1", DECIMAL(14, 3))}, + {"c2", partitionKey("c2", DECIMAL(20, 4))}}; + + auto op = PlanBuilder() + .startTableScan() + .outputType(rowType) + .assignments(assignments) + .endTableScan() + .planNode(); + + assertQuery(op, splits, fmt::format("SELECT * FROM tmp")); +} + TEST_F(HiveDataSinkTest, close) { for (bool empty : {true, false}) { SCOPED_TRACE(fmt::format("Data sink is empty: {}", empty)); @@ -1375,6 +1436,49 @@ TEST_F(HiveDataSinkTest, lazyVectorForParquet) { } #endif +// 
Test to verify that each writer has its own nonReclaimableSection +// pointer when writerOptions is shared. +TEST_F(HiveDataSinkTest, sharedWriterOptionsWithMultipleWriters) { + const auto outputDirectory = TempDirectoryPath::create(); + + const int32_t numBuckets = 3; + auto bucketProperty = std::make_shared( + HiveBucketProperty::Kind::kHiveCompatible, + numBuckets, + std::vector{"c0"}, + std::vector{BIGINT()}, + std::vector>{}); + + // Create shared writer options (this simulates the scenario where + // insertTableHandle_->writerOptions() returns a shared object) + auto sharedWriterOptions = std::make_shared(); + + // Create a data sink with multiple writers (one for each bucket) + auto dataSink = createDataSink( + rowType_, + outputDirectory->getPath(), + dwio::common::FileFormat::DWRF, + {}, + bucketProperty, + sharedWriterOptions); + + const auto vectors = createVectors(200, 3); + + // Write data - this should work without throwing exceptions + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + while (!dataSink->finish()) { + } + const auto partitions = dataSink->close(); + + ASSERT_GT(partitions.size(), 1); + createDuckDbTable(vectors); + verifyWrittenData( + outputDirectory->getPath(), static_cast(partitions.size())); +} + } // namespace } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp b/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp index 4fdf68684b04..bb41a6bc36ab 100644 --- a/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp +++ b/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp @@ -124,6 +124,58 @@ TEST_F(HivePartitionFunctionTest, bigint) { assertPartitionsWithConstChannel(values, 997); } +TEST_F(HivePartitionFunctionTest, shortDecimal) { + auto values = makeNullableFlatVector( + {std::nullopt, + 300'000'000'000, + 123456789, + DecimalUtil::kShortDecimalMin / 100, + DecimalUtil::kShortDecimalMax / 100}, + DECIMAL(18, 2)); + + 
assertPartitions(values, 1, {0, 0, 0, 0, 0}); + assertPartitions(values, 2, {0, 1, 1, 1, 1}); + assertPartitions(values, 500, {0, 471, 313, 115, 37}); + assertPartitions(values, 997, {0, 681, 6, 982, 502}); + + assertPartitionsWithConstChannel(values, 1); + assertPartitionsWithConstChannel(values, 2); + assertPartitionsWithConstChannel(values, 500); + assertPartitionsWithConstChannel(values, 997); + + values = makeFlatVector( + {123456789, DecimalUtil::kShortDecimalMin, DecimalUtil::kShortDecimalMax}, + DECIMAL(18, 0)); + assertPartitions(values, 500, {311, 236, 412}); +} + +TEST_F(HivePartitionFunctionTest, longDecimal) { + auto values = makeNullableFlatVector( + {std::nullopt, + 300'000'000'000, + HugeInt::parse("12345678901234567891"), + DecimalUtil::kLongDecimalMin / 100, + DecimalUtil::kLongDecimalMax / 100}, + DECIMAL(38, 2)); + + assertPartitions(values, 1, {0, 0, 0, 0, 0}); + assertPartitions(values, 2, {0, 1, 1, 1, 1}); + assertPartitions(values, 500, {0, 471, 99, 49, 103}); + assertPartitions(values, 997, {0, 681, 982, 481, 6}); + + assertPartitionsWithConstChannel(values, 1); + assertPartitionsWithConstChannel(values, 2); + assertPartitionsWithConstChannel(values, 500); + assertPartitionsWithConstChannel(values, 997); + + values = makeNullableFlatVector( + {HugeInt::parse("1234567890123456789112345678"), + DecimalUtil::kLongDecimalMin, + DecimalUtil::kLongDecimalMax}, + DECIMAL(38, 0)); + assertPartitions(values, 997, {51, 835, 645}); +} + TEST_F(HivePartitionFunctionTest, varchar) { auto values = makeNullableFlatVector( {std::nullopt, diff --git a/velox/connectors/hive/tests/HivePartitionUtilTest.cpp b/velox/connectors/hive/tests/HivePartitionUtilTest.cpp index 8598f46742fe..3c1575697872 100644 --- a/velox/connectors/hive/tests/HivePartitionUtilTest.cpp +++ b/velox/connectors/hive/tests/HivePartitionUtilTest.cpp @@ -74,7 +74,9 @@ TEST_F(HivePartitionUtilTest, partitionName) { "flat_bigint_col", "dict_string_col", "const_date_col", - 
"flat_timestamp_col"}, + "flat_timestamp_col", + "short_decimal_col", + "long_decimal_col"}, {makeFlatVector(std::vector{false}), makeFlatVector(std::vector{10}), makeFlatVector(std::vector{100}), @@ -83,7 +85,10 @@ TEST_F(HivePartitionUtilTest, partitionName) { makeDictionary(std::vector{"str1000"}), makeConstant(10000, 1, DATE()), makeFlatVector( - std::vector{Timestamp::fromMillis(1577836800000)})}); + std::vector{Timestamp::fromMillis(1577836800000)}), + makeConstant(10000, 1, DECIMAL(12, 2)), + makeConstant( + DecimalUtil::kLongDecimalMin / 100, 1, DECIMAL(38, 2))}); std::vector expectedPartitionKeyValues{ "flat_bool_col=false", @@ -93,7 +98,9 @@ TEST_F(HivePartitionUtilTest, partitionName) { "flat_bigint_col=10000", "dict_string_col=str1000", "const_date_col=1997-05-19", - "flat_timestamp_col=2019-12-31 16%3A00%3A00.0"}; + "flat_timestamp_col=2019-12-31 16%3A00%3A00.0", + "short_decimal_col=100.00", + "long_decimal_col=-" + std::string(34, '9') + ".99"}; std::vector partitionChannels; for (auto i = 1; i <= expectedPartitionKeyValues.size(); i++) { @@ -140,7 +147,9 @@ TEST_F(HivePartitionUtilTest, partitionNameForNull) { "flat_bigint_col", "flat_string_col", "const_date_col", - "flat_timestamp_col"}; + "flat_timestamp_col", + "short_decimal_col", + "long_decimal_col"}; RowVectorPtr input = makeRowVector( partitionColumnNames, @@ -151,7 +160,9 @@ TEST_F(HivePartitionUtilTest, partitionNameForNull) { makeNullableFlatVector({std::nullopt}), makeNullableFlatVector({std::nullopt}), makeConstant(std::nullopt, 1, DATE()), - makeNullableFlatVector({std::nullopt})}); + makeNullableFlatVector({std::nullopt}), + makeConstant(std::nullopt, 1, DECIMAL(12, 2)), + makeConstant(std::nullopt, 1, DECIMAL(38, 2))}); for (auto i = 0; i < partitionColumnNames.size(); i++) { std::vector partitionChannels = {(column_index_t)i}; diff --git a/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp b/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp index 
271e4599d3f0..565d0325ad54 100644 --- a/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp +++ b/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp @@ -322,8 +322,9 @@ TEST_F(PartitionIdGeneratorTest, supportedPartitionKeyTypes) { INTEGER(), BIGINT(), TIMESTAMP(), + DECIMAL(20, 2), }), - {0, 1, 2, 3, 4, 5, 6, 7}, + {0, 1, 2, 3, 4, 5, 6, 7, 8}, 100, pool(), true); @@ -341,7 +342,9 @@ TEST_F(PartitionIdGeneratorTest, supportedPartitionKeyTypes) { makeNullableFlatVector( {std::nullopt, Timestamp::fromMillis(1639426440001), - Timestamp::fromMillis(1639426440002)})}); + Timestamp::fromMillis(1639426440002)}), + makeNullableFlatVector( + {std::nullopt, 1, DecimalUtil::kLongDecimalMin})}); raw_vector ids; idGenerator.run(input, ids); diff --git a/velox/connectors/tests/ConnectorTest.cpp b/velox/connectors/tests/ConnectorTest.cpp index a58cf8227772..90de53198ad1 100644 --- a/velox/connectors/tests/ConnectorTest.cpp +++ b/velox/connectors/tests/ConnectorTest.cpp @@ -15,15 +15,11 @@ */ #include "velox/connectors/Connector.h" -#include "velox/common/base/tests/GTestUtils.h" #include "velox/common/config/Config.h" #include namespace facebook::velox::connector { - -class ConnectorTest : public testing::Test {}; - namespace { class TestConnector : public connector::Connector { @@ -49,9 +45,7 @@ class TestConnector : public connector::Connector { class TestConnectorFactory : public connector::ConnectorFactory { public: - static constexpr const char* kConnectorFactoryName = "test-factory"; - - TestConnectorFactory() : ConnectorFactory(kConnectorFactoryName) {} + TestConnectorFactory() : ConnectorFactory("test-factory") {} std::shared_ptr newConnector( const std::string& id, @@ -62,39 +56,30 @@ class TestConnectorFactory : public connector::ConnectorFactory { } }; -} // namespace +TEST(ConnectorTest, getAllConnectors) { + TestConnectorFactory factory; -TEST_F(ConnectorTest, getAllConnectors) { - registerConnectorFactory(std::make_shared()); - VELOX_ASSERT_THROW( - 
registerConnectorFactory(std::make_shared()), - "ConnectorFactory with name 'test-factory' is already registered"); - EXPECT_TRUE(hasConnectorFactory(TestConnectorFactory::kConnectorFactoryName)); const int32_t numConnectors = 10; for (int32_t i = 0; i < numConnectors; i++) { - registerConnector( - getConnectorFactory(TestConnectorFactory::kConnectorFactoryName) - ->newConnector( - fmt::format("connector-{}", i), - std::make_shared( - std::unordered_map()))); + registerConnector(factory.newConnector( + fmt::format("connector-{}", i), + std::make_shared( + std::unordered_map()))); } + const auto& connectors = getAllConnectors(); EXPECT_EQ(connectors.size(), numConnectors); for (int32_t i = 0; i < numConnectors; i++) { EXPECT_EQ(connectors.count(fmt::format("connector-{}", i)), 1); } + for (int32_t i = 0; i < numConnectors; i++) { unregisterConnector(fmt::format("connector-{}", i)); } EXPECT_EQ(getAllConnectors().size(), 0); - EXPECT_TRUE( - unregisterConnectorFactory(TestConnectorFactory::kConnectorFactoryName)); - EXPECT_FALSE( - unregisterConnectorFactory(TestConnectorFactory::kConnectorFactoryName)); } -TEST_F(ConnectorTest, connectorSplit) { +TEST(ConnectorTest, connectorSplit) { { const ConnectorSplit split("test", 100, true); ASSERT_EQ(split.connectorId, "test"); @@ -114,4 +99,5 @@ TEST_F(ConnectorTest, connectorSplit) { "[split: connector id test, weight 50, cacheable false]"); } } +} // namespace } // namespace facebook::velox::connector diff --git a/velox/connectors/tpcds/TpcdsConnector.cpp b/velox/connectors/tpcds/TpcdsConnector.cpp index 5902f49da4c1..b981bda903ad 100644 --- a/velox/connectors/tpcds/TpcdsConnector.cpp +++ b/velox/connectors/tpcds/TpcdsConnector.cpp @@ -65,7 +65,7 @@ TpcdsDataSource::TpcdsDataSource( handle, "ColumnHandle must be an instance of TpcdsColumnHandle " "for '{}' on table '{}'", - handle->name(), + it->second->name(), toTableName(table_)); auto idx = tpcdsTableSchema->getChildIdxIfExists(handle->name()); diff --git 
a/velox/connectors/tpcds/TpcdsConnector.h b/velox/connectors/tpcds/TpcdsConnector.h index 88ad845ff393..329114d90057 100644 --- a/velox/connectors/tpcds/TpcdsConnector.h +++ b/velox/connectors/tpcds/TpcdsConnector.h @@ -102,7 +102,7 @@ class TpcdsDataSource : public velox::connector::DataSource { return completedBytes_; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { return {}; } diff --git a/velox/connectors/tpcds/TpcdsConnectorSplit.h b/velox/connectors/tpcds/TpcdsConnectorSplit.h index 1f51eb22af61..80fb8863e76a 100644 --- a/velox/connectors/tpcds/TpcdsConnectorSplit.h +++ b/velox/connectors/tpcds/TpcdsConnectorSplit.h @@ -53,8 +53,8 @@ template <> struct fmt::formatter : formatter { auto format( - facebook::velox::connector::tpcds::TpcdsConnectorSplit s, - format_context& ctx) { + facebook::velox::connector::tpcds::TpcdsConnectorSplit const& s, + format_context& ctx) const { return formatter::format(s.toString(), ctx); } }; @@ -64,8 +64,9 @@ struct fmt::formatter< std::shared_ptr> : formatter { auto format( - std::shared_ptr s, - format_context& ctx) { + std::shared_ptr< + facebook::velox::connector::tpcds::TpcdsConnectorSplit> const& s, + format_context& ctx) const { return formatter::format(s->toString(), ctx); } }; diff --git a/velox/connectors/tpch/TpchConnector.h b/velox/connectors/tpch/TpchConnector.h index 5d006490bad8..1ec4aad0a551 100644 --- a/velox/connectors/tpch/TpchConnector.h +++ b/velox/connectors/tpch/TpchConnector.h @@ -117,7 +117,7 @@ class TpchDataSource : public DataSource { return completedBytes_; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { // TODO: Which stats do we want to expose here? 
return {}; } diff --git a/velox/core/Expressions.cpp b/velox/core/Expressions.cpp index 7ab4c78d0484..be462632b63f 100644 --- a/velox/core/Expressions.cpp +++ b/velox/core/Expressions.cpp @@ -466,7 +466,7 @@ TypedExprPtr FieldAccessTypedExpr::rewriteInputNames( VELOX_CHECK_EQ(1, newInputs.size()); // Only rewrite name if input in InputTypedExpr. Rewrite in other // cases(like dereference) is unsound. - if (!is_instance_of(newInputs[0])) { + if (!newInputs[0]->isInputKind()) { return std::make_shared(type(), newInputs[0], name_); } auto it = mapping.find(name_); diff --git a/velox/core/Expressions.h b/velox/core/Expressions.h index 699351d98bd5..4d406b347b74 100644 --- a/velox/core/Expressions.h +++ b/velox/core/Expressions.h @@ -28,7 +28,7 @@ class InputTypedExpr : public ITypedExpr { : ITypedExpr{ExprKind::kInput, std::move(type)} {} bool operator==(const ITypedExpr& other) const final { - return is_instance_of(&other); + return other.isInputKind(); } std::string toString() const override { @@ -268,7 +268,7 @@ class FieldAccessTypedExpr : public ITypedExpr { FieldAccessTypedExpr(TypePtr type, TypedExprPtr input, std::string name) : ITypedExpr{ExprKind::kFieldAccess, std::move(type), {std::move(input)}}, name_(std::move(name)), - isInputColumn_(is_instance_of(inputs()[0].get())) {} + isInputColumn_(inputs()[0]->isInputKind()) {} const std::string& name() const { return name_; @@ -338,7 +338,7 @@ class DereferenceTypedExpr : public ITypedExpr { index_(index) { // Make sure this isn't being used to access a top level column. VELOX_USER_CHECK( - !is_instance_of(inputs()[0]), + !inputs()[0]->isInputKind(), "DereferenceTypedExpr select a subfeild cannot be used to access a top level column"); } @@ -591,7 +591,7 @@ class TypedExprs { public: /// Returns true if 'expr' is a field access expression. 
static bool isFieldAccess(const TypedExprPtr& expr) { - return is_instance_of(expr); + return expr->isFieldAccessKind(); } /// Returns 'expr' as FieldAccessTypedExprPtr or null if not field access @@ -602,7 +602,7 @@ class TypedExprs { /// Returns true if 'expr' is a constant expression. static bool isConstant(const TypedExprPtr& expr) { - return is_instance_of(expr); + return expr->isConstantKind(); } /// Returns 'expr' as ConstantTypedExprPtr or null if not a constant @@ -613,7 +613,7 @@ class TypedExprs { /// Returns true if 'expr' is a lambda expression. static bool isLambda(const TypedExprPtr& expr) { - return is_instance_of(expr); + return expr->isLambdaKind(); } /// Returns 'expr' as LambdaTypedExprPtr or null if not a lambda expression. diff --git a/velox/core/PlanConsistencyChecker.cpp b/velox/core/PlanConsistencyChecker.cpp index 89c95c0d6542..dfc50dda9b57 100644 --- a/velox/core/PlanConsistencyChecker.cpp +++ b/velox/core/PlanConsistencyChecker.cpp @@ -24,6 +24,35 @@ class Checker : public PlanNodeVisitor { public: void visit(const AggregationNode& node, PlanNodeVisitorContext& ctx) const override { + const auto& rowType = node.sources().at(0)->outputType(); + for (const auto& expr : node.groupingKeys()) { + checkInputs(expr, rowType); + } + + for (const auto& expr : node.preGroupedKeys()) { + checkInputs(expr, rowType); + } + + for (const auto& aggregate : node.aggregates()) { + checkInputs(aggregate.call, rowType); + + for (const auto& expr : aggregate.sortingKeys) { + checkInputs(expr, rowType); + } + + if (aggregate.mask) { + checkInputs(aggregate.mask, rowType); + } + } + + // Verify that output column names are not empty and unique. 
+ std::unordered_set names; + for (const auto& name : node.outputType()->names()) { + VELOX_USER_CHECK(!name.empty(), "Output column name cannot be empty"); + VELOX_USER_CHECK( + names.insert(name).second, "Duplicate output column: {}", name); + } + visitSources(&node, ctx); } diff --git a/velox/core/PlanNode.cpp b/velox/core/PlanNode.cpp index 2cf53c91020a..38ed6c47a18d 100644 --- a/velox/core/PlanNode.cpp +++ b/velox/core/PlanNode.cpp @@ -15,6 +15,7 @@ */ #include +#include "velox/common/Casts.h" #include "velox/common/encode/Base64.h" #include "velox/core/PlanNode.h" #include "velox/vector/VectorSaver.h" @@ -59,7 +60,11 @@ IndexLookupConditionPtr createIndexJoinCondition( } } // namespace -std::vector deserializeJoinConditions( +/// Deserializes lookup conditions from dynamic object for index lookup joins. +/// These conditions are more complex than simple equality join conditions and +/// can include IN, BETWEEN, and EQUAL conditions that involve both left and +/// right side columns. +std::vector deserializejoinConditions( const folly::dynamic& obj, void* context) { if (obj.count("joinConditions") == 0) { @@ -386,6 +391,10 @@ std::vector deserializeFields( array, context); } +FieldAccessTypedExprPtr deserializeField(const folly::dynamic& obj) { + return ISerializable::deserialize(obj); +} + std::vector deserializeStrings(const folly::dynamic& array) { return ISerializable::deserialize>(array); } @@ -1343,14 +1352,14 @@ UnnestNode::UnnestNode( std::vector unnestVariables, std::vector unnestNames, std::optional ordinalityName, - std::optional emptyUnnestValueName, + std::optional markerName, const PlanNodePtr& source) : PlanNode(id), replicateVariables_{std::move(replicateVariables)}, unnestVariables_{std::move(unnestVariables)}, unnestNames_{std::move(unnestNames)}, ordinalityName_{std::move(ordinalityName)}, - emptyUnnestValueName_(std::move(emptyUnnestValueName)), + markerName_(std::move(markerName)), sources_{source} { // Calculate output type. 
First come "replicate" columns, followed by // "unnest" columns, followed by an optional ordinality column. @@ -1387,8 +1396,8 @@ UnnestNode::UnnestNode( types.emplace_back(BIGINT()); } - if (emptyUnnestValueName_.has_value()) { - names.emplace_back(emptyUnnestValueName_.value()); + if (markerName_.has_value()) { + names.emplace_back(markerName_.value()); types.emplace_back(BOOLEAN()); } @@ -1408,8 +1417,8 @@ folly::dynamic UnnestNode::serialize() const { if (ordinalityName_.has_value()) { obj["ordinalityName"] = ordinalityName_.value(); } - if (emptyUnnestValueName_.has_value()) { - obj["emptyUnnestValueName"] = emptyUnnestValueName_.value(); + if (markerName_.has_value()) { + obj["markerName"] = markerName_.value(); } return obj; } @@ -1431,9 +1440,9 @@ PlanNodePtr UnnestNode::create(const folly::dynamic& obj, void* context) { if (obj.count("ordinalityName")) { ordinalityName = obj["ordinalityName"].asString(); } - std::optional emptyUnnestValueName = std::nullopt; - if (obj.count("emptyUnnestValueName")) { - emptyUnnestValueName = obj["emptyUnnestValueName"].asString(); + std::optional markerName = std::nullopt; + if (obj.count("markerName")) { + markerName = obj["markerName"].asString(); } return std::make_shared( deserializePlanNodeId(obj), @@ -1441,7 +1450,7 @@ PlanNodePtr UnnestNode::create(const folly::dynamic& obj, void* context) { std::move(unnestVariables), std::move(unnestNames), std::move(ordinalityName), - std::move(emptyUnnestValueName), + std::move(markerName), std::move(source)); } @@ -1730,7 +1739,8 @@ IndexLookupJoinNode::IndexLookupJoinNode( const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + TypedExprPtr filter, + bool hasMarker, PlanNodePtr left, TableScanNodePtr right, RowTypePtr outputType) @@ -1739,13 +1749,13 @@ IndexLookupJoinNode::IndexLookupJoinNode( joinType, leftKeys, rightKeys, - /*filter=*/nullptr, + std::move(filter), std::move(left), right, outputType), 
lookupSourceNode_(std::move(right)), joinConditions_(joinConditions), - includeMatchColumn_(includeMatchColumn) { + hasMarker_(hasMarker) { VELOX_USER_CHECK( !leftKeys.empty(), "The index lookup join node requires at least one join key"); @@ -1789,7 +1799,7 @@ IndexLookupJoinNode::IndexLookupJoinNode( } auto numOutputColumns = outputType_->size(); - if (includeMatchColumn_) { + if (hasMarker_) { VELOX_USER_CHECK( isLeftJoin(), "Index join match column can only present for {} but not {}", @@ -1818,17 +1828,19 @@ PlanNodePtr IndexLookupJoinNode::create( auto sources = deserializeSources(obj, context); VELOX_CHECK_EQ(2, sources.size()); TableScanNodePtr lookupSource = - std::dynamic_pointer_cast(sources[1]); - VELOX_CHECK_NOT_NULL(lookupSource); + checked_pointer_cast(sources[1]); auto leftKeys = deserializeFields(obj["leftKeys"], context); auto rightKeys = deserializeFields(obj["rightKeys"], context); - VELOX_CHECK_EQ(obj.count("filter"), 0); + TypedExprPtr filter; + if (obj.count("filter")) { + filter = ISerializable::deserialize(obj["filter"], context); + } - auto joinConditions = deserializeJoinConditions(obj, context); + auto joinConditions = deserializejoinConditions(obj, context); - const bool includeMatchColumn = obj["includeMatchColumn"].asBool(); + const bool hasMarker = obj["hasMarker"].asBool(); auto outputType = deserializeRowType(obj["outputType"]); @@ -1838,7 +1850,8 @@ PlanNodePtr IndexLookupJoinNode::create( std::move(leftKeys), std::move(rightKeys), std::move(joinConditions), - includeMatchColumn, + filter, + hasMarker, sources[0], std::move(lookupSource), std::move(outputType)); @@ -1853,7 +1866,10 @@ folly::dynamic IndexLookupJoinNode::serialize() const { } obj["joinConditions"] = std::move(serializedJoins); } - obj["includeMatchColumn"] = includeMatchColumn_; + if (filter_) { + obj["filter"] = filter_->serialize(); + } + obj["hasMarker"] = hasMarker_; return obj; } @@ -1863,14 +1879,15 @@ void IndexLookupJoinNode::addDetails(std::stringstream& 
stream) const { return; } - std::vector joinConditionStrs; - joinConditionStrs.reserve(joinConditions_.size()); + std::vector joinConditionstrs; + joinConditionstrs.reserve(joinConditions_.size()); for (const auto& joinCondition : joinConditions_) { - joinConditionStrs.push_back(joinCondition->toString()); + joinConditionstrs.push_back(joinCondition->toString()); } - stream << ", joinConditions: [" << folly::join(", ", joinConditionStrs) - << " ], includeMatchColumn: [" - << (includeMatchColumn_ ? "true" : "false") << "]"; + stream << ", joinConditions: [" << folly::join(", ", joinConditionstrs) + << " ], filter: [" + << (filter_ == nullptr ? "null" : filter_->toString()) + << "], hasMarker: [" << (hasMarker_ ? "true" : "false") << "]"; } void IndexLookupJoinNode::accept( @@ -1892,9 +1909,7 @@ bool IndexLookupJoinNode::isSupported(JoinType joinType) { } bool isIndexLookupJoin(const PlanNode* planNode) { - const auto* indexLookupJoin = - dynamic_cast(planNode); - return indexLookupJoin != nullptr; + return is_instance_of(planNode); } // static @@ -2709,11 +2724,6 @@ PlanNodePtr TableWriteNode::create(const folly::dynamic& obj, void* context) { auto columns = deserializeRowType(obj["columns"]); auto columnNames = ISerializable::deserialize>(obj["columnNames"]); - AggregationNodePtr aggregationNode; - if (obj.count("aggregationNode") != 0) { - aggregationNode = ISerializable::deserialize( - obj["aggregationNode"], context); - } auto connectorId = obj["connectorId"].asString(); auto connectorInsertTableHandle = ISerializable::deserialize( @@ -3069,21 +3079,30 @@ SpatialJoinNode::SpatialJoinNode( const PlanNodeId& id, JoinType joinType, TypedExprPtr joinCondition, + FieldAccessTypedExprPtr probeGeometry, + FieldAccessTypedExprPtr buildGeometry, + std::optional radius, PlanNodePtr left, PlanNodePtr right, RowTypePtr outputType) : PlanNode(id), joinType_(joinType), joinCondition_(std::move(joinCondition)), + probeGeometry_(std::move(probeGeometry)), + 
buildGeometry_(std::move(buildGeometry)), + radius_(std::move(radius)), sources_({std::move(left), std::move(right)}), outputType_(std::move(outputType)) { VELOX_USER_CHECK( isSupported(joinType_), "The join type is not supported by spatial join: {}", JoinTypeName::toName(joinType_)); - VELOX_USER_CHECK( - joinCondition_ != nullptr, - "The join condition must not be null for spatial join"); + VELOX_USER_CHECK_NOT_NULL( + joinCondition_, "The join condition must not be null for spatial join"); + VELOX_USER_CHECK_NOT_NULL( + probeGeometry_, "Probe geometery must not be null for spatial joins"); + VELOX_USER_CHECK_NOT_NULL( + buildGeometry_, "Build geometery must not be null for spatial joins"); VELOX_USER_CHECK_EQ( sources_.size(), 2, "Must have 2 sources for spatial joins"); VELOX_USER_CHECK( @@ -3115,6 +3134,11 @@ void SpatialJoinNode::addDetails(std::stringstream& stream) const { if (joinCondition_) { stream << ", joinCondition: " << joinCondition_->toString(); } + stream << ", probeGeometry: " << probeGeometry_->name(); + stream << ", buildGeometry: " << buildGeometry_->name(); + if (radius_) { + stream << ", radius: " << radius_.value()->name(); + } } folly::dynamic SpatialJoinNode::serialize() const { @@ -3124,6 +3148,11 @@ folly::dynamic SpatialJoinNode::serialize() const { obj["joinCondition"] = joinCondition_->serialize(); } obj["outputType"] = outputType_->serialize(); + obj["probeGeometry"] = probeGeometry_->serialize(); + obj["buildGeometry"] = buildGeometry_->serialize(); + if (radius_) { + obj["radius"] = radius_.value()->serialize(); + } return obj; } @@ -3144,11 +3173,20 @@ PlanNodePtr SpatialJoinNode::create(const folly::dynamic& obj, void* context) { } auto outputType = deserializeRowType(obj["outputType"]); + auto probeGeometry = deserializeField(obj["probeGeometry"]); + auto buildGeometry = deserializeField(obj["buildGeometry"]); + std::optional radius; + if (obj.count("radius")) { + radius = deserializeField(obj["radius"]); + } return 
std::make_shared( deserializePlanNodeId(obj), JoinTypeName::toJoinType(obj["joinType"].asString()), joinCondition, + probeGeometry, + buildGeometry, + radius, sources[0], sources[1], outputType); @@ -3425,6 +3463,25 @@ void PlanNode::toSkeletonString( } } +// static +const PlanNode* PlanNode::findFirstNode( + const PlanNode* root, + const std::function& predicate) { + VELOX_CHECK_NOT_NULL(root); + if (predicate(root)) { + return root; + } + + // Recursively go further through the sources. + for (const auto& source : root->sources()) { + const auto* ret = PlanNode::findFirstNode(source.get(), predicate); + if (ret != nullptr) { + return ret; + } + } + return nullptr; +} + namespace { void collectLeafPlanNodeIds( const PlanNode& planNode, @@ -3587,7 +3644,7 @@ folly::dynamic IndexLookupCondition::serialize() const { } bool InIndexLookupCondition::isFilter() const { - return std::dynamic_pointer_cast(list) != nullptr; + return list->isConstantKind(); } folly::dynamic InIndexLookupCondition::serialize() const { @@ -3605,16 +3662,13 @@ void InIndexLookupCondition::validate() const { VELOX_CHECK_NOT_NULL(key); VELOX_CHECK_NOT_NULL(list); VELOX_CHECK( - std::dynamic_pointer_cast(list) || - std::dynamic_pointer_cast(list), + list->isFieldAccessKind() || list->isConstantKind(), "Invalid condition list {}", list->toString()); - const auto listType = - std::dynamic_pointer_cast(list->type()); - VELOX_CHECK_NOT_NULL(listType); + const auto& listType = list->type()->asArray(); VELOX_CHECK_EQ( key->type()->kind(), - listType->elementType()->kind(), + listType.elementType()->kind(), "In condition key and list condition element must have the same type"); } @@ -3632,9 +3686,7 @@ IndexLookupConditionPtr InIndexLookupCondition::create( } bool BetweenIndexLookupCondition::isFilter() const { - return (std::dynamic_pointer_cast(lower) != - nullptr) && - (std::dynamic_pointer_cast(upper) != nullptr); + return lower->isConstantKind() && upper->isConstantKind(); } folly::dynamic 
BetweenIndexLookupCondition::serialize() const { @@ -3669,14 +3721,12 @@ void BetweenIndexLookupCondition::validate() const { VELOX_CHECK_NOT_NULL(lower); VELOX_CHECK_NOT_NULL(upper); VELOX_CHECK( - std::dynamic_pointer_cast(lower) || - std::dynamic_pointer_cast(lower), + lower->isFieldAccessKind() || lower->isConstantKind(), "Invalid lower between condition {}", lower->toString()); VELOX_CHECK( - std::dynamic_pointer_cast(upper) || - std::dynamic_pointer_cast(upper), + upper->isFieldAccessKind() || upper->isConstantKind(), "Invalid upper between condition {}", upper->toString()); @@ -3692,7 +3742,7 @@ void BetweenIndexLookupCondition::validate() const { } bool EqualIndexLookupCondition::isFilter() const { - return std::dynamic_pointer_cast(value) != nullptr; + return value->isConstantKind(); } folly::dynamic EqualIndexLookupCondition::serialize() const { @@ -3719,7 +3769,7 @@ void EqualIndexLookupCondition::validate() const { VELOX_CHECK_NOT_NULL(key); VELOX_CHECK_NOT_NULL(value); VELOX_CHECK_NOT_NULL( - std::dynamic_pointer_cast(value), + checked_pointer_cast(value), "Equal condition value must be a constant expression: {}", value->toString()); diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index b5c82d29a2cb..cf8765812ff5 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -259,23 +259,28 @@ class PlanNode : public ISerializable { /// The name of the plan node, used in toString. virtual std::string_view name() const = 0; + template + bool is() const { + return dynamic_cast(this) != nullptr; + } + + template + const T* as() const { + return dynamic_cast(this); + } + /// Recursively checks the node tree for a first node that satisfy a given /// condition. Returns pointer to the node if found, nullptr if not. 
static const PlanNode* findFirstNode( - const PlanNode* node, - const std::function& predicate) { - if (predicate(node)) { - return node; - } + const PlanNode* root, + const std::function& predicate); - // Recursively go further through the sources. - for (const auto& source : node->sources()) { - const auto* ret = PlanNode::findFirstNode(source.get(), predicate); - if (ret != nullptr) { - return ret; - } - } - return nullptr; + /// @return PlanNode with matching ID or nullptr if not found. + static const PlanNode* findNodeById( + const PlanNode* root, + const PlanNodeId& id) { + return findFirstNode( + root, [&](const auto* node) { return node->id() == id; }); } private: @@ -1103,14 +1108,14 @@ class AggregationNode : public PlanNode { /// Optional name of input column to use as a mask. Column type must be /// BOOLEAN. - FieldAccessTypedExprPtr mask; + FieldAccessTypedExprPtr mask{}; /// Optional list of input columns to sort by before applying aggregate /// function. - std::vector sortingKeys; + std::vector sortingKeys{}; /// A list of sorting orders that goes together with 'sortingKeys'. - std::vector sortingOrders; + std::vector sortingOrders{}; /// Boolean indicating whether inputs must be de-duplicated before /// aggregating. @@ -3431,7 +3436,15 @@ class IndexLookupJoinNode : public AbstractJoinNode { public: /// @param joinType Specifies the lookup join type. Only INNER and LEFT joins /// are supported. - /// @param includeMatchColumn if true, the output type includes a boolean + /// @param leftKeys Left side join keys used for index lookup. + /// @param rightKeys Right side join keys that form the index prefix. + /// @param joinConditions Additional conditions for index lookup that can't + /// be converted into simple equality join conditions. These conditions use + /// columns from both left and right and exactly one index column from + /// the right side.sides + /// @param filter Additional filter to apply on join results. 
This supports + /// filters that can't be converted into join conditions. + /// @param hasMarker if true, the output type includes a boolean /// column at the end to indicate if a join output row has a match or not. /// This only applies for left join. IndexLookupJoinNode( @@ -3440,7 +3453,8 @@ class IndexLookupJoinNode : public AbstractJoinNode { const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + TypedExprPtr filter, + bool hasMarker, PlanNodePtr left, TableScanNodePtr right, RowTypePtr outputType); @@ -3453,16 +3467,27 @@ class IndexLookupJoinNode : public AbstractJoinNode { explicit Builder(const IndexLookupJoinNode& other) : AbstractJoinNode::Builder(other) { joinConditions_ = other.joinConditions(); + filter_ = other.filter(); + hasMarker_ = other.hasMarker(); } + /// Set lookup conditions for index lookup that can't be converted into + /// simple equality join conditions. Builder& joinConditions( std::vector joinConditions) { joinConditions_ = std::move(joinConditions); return *this; } - Builder& includeMatchColumn(bool includeMatchColumn) { - includeMatchColumn_ = includeMatchColumn; + /// Set additional filter to apply on join results. + Builder& filter(TypedExprPtr filter) { + filter_ = std::move(filter); + return *this; + } + + /// Set whether to include a marker column for left joins. 
+ Builder& hasMarker(bool hasMarker) { + hasMarker_ = hasMarker; return *this; } @@ -3480,25 +3505,23 @@ class IndexLookupJoinNode : public AbstractJoinNode { right_.has_value(), "IndexLookupJoinNode right source is not set"); VELOX_USER_CHECK( outputType_.has_value(), "IndexLookupJoinNode outputType is not set"); - VELOX_USER_CHECK( - joinConditions_.has_value(), - "IndexLookupJoinNode join conditions are not set"); return std::make_shared( id_.value(), joinType_.value(), leftKeys_.value(), rightKeys_.value(), - joinConditions_.value(), - includeMatchColumn_, + joinConditions_, + filter_.value_or(nullptr), + hasMarker_, left_.value(), std::dynamic_pointer_cast(right_.value()), outputType_.value()); } private: - std::optional> joinConditions_; - bool includeMatchColumn_; + std::vector joinConditions_; + bool hasMarker_; }; bool supportsBarrier() const override { @@ -3509,6 +3532,8 @@ class IndexLookupJoinNode : public AbstractJoinNode { return lookupSourceNode_; } + /// Returns the join conditions for index lookup that can't be converted into + /// simple equality join conditions. const std::vector& joinConditions() const { return joinConditions_; } @@ -3517,8 +3542,9 @@ class IndexLookupJoinNode : public AbstractJoinNode { return "IndexLookupJoin"; } - bool includeMatchColumn() const { - return includeMatchColumn_; + /// Returns whether this node includes a marker column for left joins. + bool hasMarker() const { + return hasMarker_; } void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) @@ -3534,11 +3560,16 @@ class IndexLookupJoinNode : public AbstractJoinNode { private: void addDetails(std::stringstream& stream) const override; + /// The table scan node that provides the lookup source for index operations. const TableScanNodePtr lookupSourceNode_; + /// Join conditions that can't be converted into simple equality join + /// conditions. 
These conditions involve columns from both left and right + /// sides and exactly one index column from the right side. const std::vector joinConditions_; - const bool includeMatchColumn_; + /// Whether to include a marker column for left joins to indicate matches. + const bool hasMarker_; }; using IndexLookupJoinNodePtr = std::shared_ptr; @@ -3852,10 +3883,29 @@ class SpatialJoinNode : public PlanNode { const PlanNodeId& id, JoinType joinType, TypedExprPtr joinCondition, + FieldAccessTypedExprPtr probeGeometry, + FieldAccessTypedExprPtr buildGeometry, + std::optional radius, PlanNodePtr left, PlanNodePtr right, RowTypePtr outputType); + SpatialJoinNode( + const PlanNodeId& id, + JoinType joinType, + TypedExprPtr joinCondition, + PlanNodePtr left, + PlanNodePtr right, + RowTypePtr outputType); + + PlanNodePtr leftNode() const { + return sources()[0]; + } + + PlanNodePtr rightNode() const { + return sources()[1]; + } + class Builder { public: Builder() = default; @@ -3864,6 +3914,9 @@ class SpatialJoinNode : public PlanNode { id_ = other.id(); joinType_ = other.joinType(); joinCondition_ = other.joinCondition(); + probeGeometry_ = other.probeGeometry(); + buildGeometry_ = other.buildGeometry(); + radius_ = other.radius(); VELOX_CHECK_EQ(other.sources().size(), 2); left_ = other.sources()[0]; right_ = other.sources()[1]; @@ -3885,6 +3938,21 @@ class SpatialJoinNode : public PlanNode { return *this; } + Builder& probeGeometry(FieldAccessTypedExprPtr probeGeometry) { + probeGeometry_ = std::move(probeGeometry); + return *this; + } + + Builder& buildGeometry(FieldAccessTypedExprPtr buildGeometry) { + buildGeometry_ = std::move(buildGeometry); + return *this; + } + + Builder& radius(FieldAccessTypedExprPtr radius) { + radius_ = std::move(radius); + return *this; + } + Builder& left(PlanNodePtr left) { left_ = std::move(left); return *this; @@ -3908,11 +3976,38 @@ class SpatialJoinNode : public PlanNode { right_.has_value(), "SpatialJoinNode right source is not set"); 
VELOX_USER_CHECK( outputType_.has_value(), "SpatialJoinNode outputType is not set"); + VELOX_USER_CHECK( + probeGeometry_.has_value(), + "SpatialJoinNode probe geometry is not set"); + VELOX_USER_CHECK( + buildGeometry_.has_value(), + "SpatialJoinNode build geometry is not set"); + + VELOX_USER_CHECK( + (probeGeometry_.has_value() && buildGeometry_.has_value()) || + (!probeGeometry_.has_value() && !buildGeometry_.has_value()), + "Either probe and build geometry must both be set, or neither"); + + if (probeGeometry_.has_value() && buildGeometry_.has_value()) { + return std::make_shared( + id_.value(), + joinType_, + joinCondition_, + probeGeometry_.value(), + buildGeometry_.value(), + radius_, + left_.value(), + right_.value(), + outputType_.value()); + } return std::make_shared( id_.value(), joinType_, joinCondition_, + probeGeometry_.value(), + buildGeometry_.value(), + radius_, left_.value(), right_.value(), outputType_.value()); @@ -3922,6 +4017,9 @@ class SpatialJoinNode : public PlanNode { std::optional id_; JoinType joinType_ = kDefaultJoinType; TypedExprPtr joinCondition_; + std::optional probeGeometry_; + std::optional buildGeometry_; + std::optional radius_; std::optional left_; std::optional right_; std::optional outputType_; @@ -3946,6 +4044,18 @@ class SpatialJoinNode : public PlanNode { return joinCondition_; } + const FieldAccessTypedExprPtr& probeGeometry() const { + return probeGeometry_; + } + + const FieldAccessTypedExprPtr& buildGeometry() const { + return buildGeometry_; + } + + const std::optional& radius() const { + return radius_; + } + JoinType joinType() const { return joinType_; } @@ -3964,6 +4074,9 @@ class SpatialJoinNode : public PlanNode { const JoinType joinType_; const TypedExprPtr joinCondition_; + const FieldAccessTypedExprPtr probeGeometry_; + const FieldAccessTypedExprPtr buildGeometry_; + const std::optional radius_; const std::vector sources_; const RowTypePtr outputType_; }; @@ -4245,17 +4358,17 @@ class UnnestNode : public 
PlanNode { /// names must appear in the same order as unnestVariables. /// @param ordinalityName Optional name for the ordinality columns. If not /// present, ordinality column is not produced. - /// @param emptyUnnestValueName Optional name for column which indicates an - /// output row has empty unnest value or not. If not present, emptyUnnestValue - /// column is not provided and the unnest operator also skips producing output - /// rows with empty unnest value. + /// @param markerName Optional name for column which indicates whether an + /// output row has non-empty unnested value. If not present, marker column is + /// not provided and the unnest operator also skips producing output rows + /// with empty unnest value. UnnestNode( const PlanNodeId& id, std::vector replicateVariables, std::vector unnestVariables, std::vector unnestNames, std::optional ordinalityName, - std::optional emptyUnnestValueName, + std::optional markerName, const PlanNodePtr& source); class Builder { @@ -4304,9 +4417,8 @@ class UnnestNode : public PlanNode { return *this; } - Builder& emptyUnnestValueName( - std::optional emptyUnnestValueName) { - emptyUnnestValueName_ = std::move(emptyUnnestValueName); + Builder& markerName(std::optional markerName) { + markerName_ = std::move(markerName); return *this; } @@ -4328,7 +4440,7 @@ class UnnestNode : public PlanNode { unnestVariables_.value(), unnestNames_.value(), ordinalityName_, - emptyUnnestValueName_, + markerName_, source_.value()); } @@ -4338,7 +4450,7 @@ class UnnestNode : public PlanNode { std::optional> unnestVariables_; std::optional> unnestNames_; std::optional ordinalityName_; - std::optional emptyUnnestValueName_; + std::optional markerName_; std::optional source_; }; @@ -4380,12 +4492,12 @@ class UnnestNode : public PlanNode { return ordinalityName_.has_value(); } - const std::optional& emptyUnnestValueName() const { - return emptyUnnestValueName_; + const std::optional& markerName() const { + return markerName_; } - bool 
hasEmptyUnnestValue() const { - return emptyUnnestValueName_.has_value(); + bool hasMarker() const { + return markerName_.has_value(); } std::string_view name() const override { @@ -4403,7 +4515,7 @@ class UnnestNode : public PlanNode { const std::vector unnestVariables_; const std::vector unnestNames_; const std::optional ordinalityName_; - const std::optional emptyUnnestValueName_; + const std::optional markerName_; const std::vector sources_; RowTypePtr outputType_; }; diff --git a/velox/core/QueryConfig.cpp b/velox/core/QueryConfig.cpp index 3d5b25ff9487..736332280bb6 100644 --- a/velox/core/QueryConfig.cpp +++ b/velox/core/QueryConfig.cpp @@ -22,29 +22,26 @@ namespace facebook::velox::core { -QueryConfig::QueryConfig( - const std::unordered_map& values) - : config_{std::make_unique( - std::unordered_map(values))} { - validateConfig(); -} +QueryConfig::QueryConfig(std::unordered_map values) + : QueryConfig{ + ConfigTag{}, + std::make_shared(std::move(values))} {} -QueryConfig::QueryConfig(std::unordered_map&& values) - : config_{std::make_unique(std::move(values))} { +QueryConfig::QueryConfig( + ConfigTag /*tag*/, + std::shared_ptr config) + : config_{std::move(config)} { validateConfig(); } void QueryConfig::validateConfig() { // Validate if timezone name can be recognized. 
- if (config_->valueExists(QueryConfig::kSessionTimezone)) { + if (auto tz = config_->get(QueryConfig::kSessionTimezone)) { VELOX_USER_CHECK( - tz::getTimeZoneID( - config_->get(QueryConfig::kSessionTimezone).value(), - false) != -1, - fmt::format( - "session '{}' set with invalid value '{}'", - QueryConfig::kSessionTimezone, - config_->get(QueryConfig::kSessionTimezone).value())); + tz::getTimeZoneID(*tz, false) != -1, + "session '{}' set with invalid value '{}'", + QueryConfig::kSessionTimezone, + *tz); } } diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index 02cde3df9b46..4efb6542b031 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -21,16 +21,21 @@ namespace facebook::velox::core { -/// A simple wrapper around velox::ConfigBase. Defines constants for query +/// A simple wrapper around velox::IConfig. Defines constants for query /// config properties and accessor methods. /// Create per query context. Does not have a singleton instance. /// Does not allow altering properties on the fly. Only at creation time. class QueryConfig { public: - explicit QueryConfig( - const std::unordered_map& values); + explicit QueryConfig(std::unordered_map values); + + // This is needed only to resolve correct ctor for cases like + // QueryConfig{{}} or QueryConfig({}). + struct ConfigTag {}; - explicit QueryConfig(std::unordered_map&& values); + explicit QueryConfig( + ConfigTag /*tag*/, + std::shared_ptr config); /// Maximum memory that a query can use on a single host. static constexpr const char* kQueryMaxMemoryPerNode = @@ -40,6 +45,11 @@ class QueryConfig { /// name, e.g: "America/Los_Angeles". static constexpr const char* kSessionTimezone = "session_timezone"; + /// Session start time in milliseconds since Unix epoch. This represents when + /// the query session began execution. Used for functions that need to know + /// the session start time (e.g., current_date, localtime). 
+ static constexpr const char* kSessionStartTime = "start_time"; + /// If true, timezone-less timestamp conversions (e.g. string to timestamp, /// when the string does not specify a timezone) will be adjusted to the user /// provided session timezone (if any). @@ -64,6 +74,19 @@ class QueryConfig { static constexpr const char* kExprTrackCpuUsage = "expression.track_cpu_usage"; + /// Controls whether non-deterministic expressions are deduplicated during + /// compilation. This is intended for testing and debugging purposes. By + /// default, this is set to true to preserve standard behavior. If set to + /// false, non-deterministic functions (such as rand()) will not be + /// deduplicated. Since non-deterministic functions may yield different + /// outputs on each call, disabling deduplication guarantees that the function + /// is executed only when the original expression is evaluated, rather than + /// being triggered for every deduplicated instance. This ensures each + /// invocation corresponds directly to the actual expression, maintaining + /// independent behavior for each call. + static constexpr const char* kExprDedupNonDeterministic = + "expression.dedup_non_deterministic"; + /// Whether to track CPU usage for stages of individual operators. True by /// default. Can be expensive when processing small batches, e.g. < 10K rows. static constexpr const char* kOperatorTrackCpuUsage = @@ -255,6 +278,13 @@ class QueryConfig { /// Window spilling flag, only applies if "spill_enabled" flag is set. static constexpr const char* kWindowSpillEnabled = "window_spill_enabled"; + /// When processing spilled window data, read batches of whole partitions + /// having at least that many rows. Set to 1 to read one whole partition at a + /// time. Each driver processing the Window operator will process that much + /// data at once. 
+ static constexpr const char* kWindowSpillMinReadBatchRows = + "window_spill_min_read_batch_rows"; + /// If true, the memory arbitrator will reclaim memory from table writer by /// flushing its buffered data to disk. only applies if "spill_enabled" flag /// is set. @@ -693,10 +723,38 @@ class QueryConfig { /// username. static constexpr const char* kClientTags = "client_tags"; +#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY + /// Enable (reader) row size tracker as a fallback to file level row size + /// estimates. + static constexpr const char* kRowSizeTrackingEnabled = + "row_size_tracking_enabled"; +#endif + + /// Enable (reader) row size tracker as a fallback to file level row size + /// estimates. + static constexpr const char* kRowSizeTrackingMode = "row_size_tracking_mode"; + + enum class RowSizeTrackingMode { + DISABLED = 0, + EXCLUDE_DELTA_SPLITS = 1, + ENABLED_FOR_ALL = 2, + }; + bool selectiveNimbleReaderEnabled() const { return get(kSelectiveNimbleReaderEnabled, false); } +#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY + bool rowSizeTrackingEnabled() const { + return get(kRowSizeTrackingEnabled, true); + } +#endif + + RowSizeTrackingMode rowSizeTrackingMode() const { + return get( + kRowSizeTrackingMode, RowSizeTrackingMode::ENABLED_FOR_ALL); + } + bool debugDisableExpressionsWithPeeling() const { return get(kDebugDisableExpressionWithPeeling, false); } @@ -907,6 +965,12 @@ class QueryConfig { return get(kSessionTimezone, ""); } + /// Returns the session start time in milliseconds since Unix epoch. + /// If not set, returns 0 (or epoch). 
+ int64_t sessionStartTimeMs() const { + return get(kSessionStartTime, 0); + } + bool exprEvalSimplified() const { return get(kExprEvalSimplified, false); } @@ -935,6 +999,10 @@ class QueryConfig { return get(kWindowSpillEnabled, true); } + uint32_t windowSpillMinReadBatchRows() const { + return get(kWindowSpillMinReadBatchRows, 1'000); + } + bool writerSpillEnabled() const { return get(kWriterSpillEnabled, true); } @@ -1101,6 +1169,10 @@ class QueryConfig { return get(kExprTrackCpuUsage, false); } + bool exprDedupNonDeterministic() const { + return get(kExprDedupNonDeterministic, true); + } + bool operatorTrackCpuUsage() const { return get(kOperatorTrackCpuUsage, true); } @@ -1258,9 +1330,14 @@ class QueryConfig { T get(const std::string& key, const T& defaultValue) const { return config_->get(key, defaultValue); } + template std::optional get(const std::string& key) const { - return std::optional(config_->get(key)); + return config_->get(key); + } + + const std::shared_ptr& config() const { + return config_; } /// Test-only method to override the current query config properties. 
@@ -1273,6 +1350,6 @@ class QueryConfig { private: void validateConfig(); - std::unique_ptr config_; + std::shared_ptr config_; }; } // namespace facebook::velox::core diff --git a/velox/core/tests/ConstantTypedExprTest.cpp b/velox/core/tests/ConstantTypedExprTest.cpp index 3d067f9ede87..75e1e317ef80 100644 --- a/velox/core/tests/ConstantTypedExprTest.cpp +++ b/velox/core/tests/ConstantTypedExprTest.cpp @@ -15,15 +15,179 @@ */ #include +#include "velox/common/memory/Memory.h" #include "velox/core/Expressions.h" #include "velox/functions/prestosql/types/HyperLogLogType.h" #include "velox/functions/prestosql/types/JsonType.h" #include "velox/functions/prestosql/types/TDigestType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/type/Variant.h" +#include "velox/vector/BaseVector.h" +#include "velox/vector/tests/utils/VectorTestBase.h" namespace facebook::velox::core::test { -TEST(ConstantTypedExprTest, null) { +namespace { +struct TestOpaqueStruct { + int value; + std::string name; + + TestOpaqueStruct(int v, std::string n) : value(v), name(std::move(n)) {} + + bool operator==(const TestOpaqueStruct& other) const { + return value == other.value && name == other.name; + } +}; + +} // namespace + +class ConstantTypedExprTest : public ::testing::Test, + public velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + pool_ = memory::memoryManager()->addLeafPool(); + + // Register serialization/deserialization functions needed for the tests + Type::registerSerDe(); + ITypedExpr::registerSerDe(); + + // Register OPAQUE type serialization for TestOpaqueStruct + static folly::once_flag once; + folly::call_once(once, []() { + OpaqueType::registerSerialization( + "TestOpaqueStruct", + [](const std::shared_ptr& obj) -> std::string { + return folly::json::serialize( + folly::dynamic::object("value", 
obj->value)("name", obj->name), + folly::json::serialization_opts{}); + }, + [](const std::string& json) -> std::shared_ptr { + folly::dynamic obj = folly::parseJson(json); + return std::make_shared( + obj["value"].asInt(), obj["name"].asString()); + }); + }); + } + + // Helper functions + template + std::shared_ptr createVariantExpr( + const TypePtr& type, + const T& value) { + return std::make_shared(type, variant(value)); + } + + std::shared_ptr createNullVariantExpr( + const TypePtr& type) { + return std::make_shared( + type, variant::null(type->kind())); + } + + std::shared_ptr createVectorExpr(const VectorPtr& vector) { + return std::make_shared(vector); + } + + template + VectorPtr createConstantVector(const TypePtr& type, const T& value) { + return BaseVector::createConstant(type, variant(value), 1, pool_.get()); + } + + VectorPtr createNullConstantVector(const TypePtr& type) { + return BaseVector::createNullConstant(type, 1, pool_.get()); + } + + // Test Data + struct TestValues { + variant nullValue; + std::vector nonNullValues; + + TestValues(TypeKind kind) : nullValue(variant::null(kind)) {} + }; + + TestValues getTestValues(TypeKind kind) { + TestValues values(kind); + + switch (kind) { + case TypeKind::BOOLEAN: + values.nonNullValues = {variant(true), variant(false)}; + break; + case TypeKind::TINYINT: + values.nonNullValues = { + variant(int8_t(0)), variant(int8_t(127)), variant(int8_t(-128))}; + break; + case TypeKind::SMALLINT: + values.nonNullValues = { + variant(int16_t(0)), + variant(int16_t(32767)), + variant(int16_t(-32768))}; + break; + case TypeKind::INTEGER: + values.nonNullValues = { + variant(int32_t(0)), + variant(int32_t(2147483647)), + variant(int32_t(-2147483648))}; + break; + case TypeKind::BIGINT: + values.nonNullValues = { + variant(int64_t(0)), + variant(int64_t(9223372036854775807LL)), + variant(int64_t(-9223372036854775808ULL))}; + break; + case TypeKind::REAL: + values.nonNullValues = {variant(0.0f), variant(3.14f), 
variant(-1.5f)}; + break; + case TypeKind::DOUBLE: + values.nonNullValues = { + variant(0.0), variant(3.14159), variant(-2.71828)}; + break; + case TypeKind::VARCHAR: + values.nonNullValues = { + variant(""), variant("hello"), variant("test string")}; + break; + case TypeKind::VARBINARY: + values.nonNullValues = { + variant::binary(""), + variant::binary("binary data"), + variant::binary("\x00\x01\x02")}; + break; + case TypeKind::TIMESTAMP: + values.nonNullValues = { + variant(Timestamp(0, 0)), + variant(Timestamp(1234567890, 123456789))}; + break; + case TypeKind::HUGEINT: + values.nonNullValues = { + variant(int128_t(0)), + variant(int128_t(123)), + variant(int128_t(-456))}; + break; + default: + // For complex types, we'll handle them within individual tests. + break; + } + return values; + } + + std::shared_ptr pool_; + const std::vector scalarTypes_ = { + TypeKind::BOOLEAN, + TypeKind::TINYINT, + TypeKind::SMALLINT, + TypeKind::INTEGER, + TypeKind::BIGINT, + TypeKind::REAL, + TypeKind::DOUBLE, + TypeKind::VARCHAR, + TypeKind::VARBINARY, + TypeKind::TIMESTAMP, + TypeKind::HUGEINT}; +}; + +TEST_F(ConstantTypedExprTest, null) { auto makeNull = [](const TypePtr& type) { return std::make_shared( type, variant::null(type->kind())); @@ -67,4 +231,352 @@ TEST(ConstantTypedExprTest, null) { *makeNull(ROW({"x", "y"}, {INTEGER(), REAL()}))); } +TEST_F(ConstantTypedExprTest, hashScalarTypes) { + // Tests the consistency of the hash value returned by the ConstantTypedExpr + // between its construction using variant and Velox vectors. 
+ for (auto kind : scalarTypes_) { + auto type = createScalarType(kind); + auto testValues = getTestValues(kind); + + // null values + auto nullVariantExpr = createNullVariantExpr(type); + auto nullVectorExpr = createVectorExpr(createNullConstantVector(type)); + EXPECT_EQ(nullVariantExpr->hash(), nullVectorExpr->hash()) + << "Hash mismatch for null " << TypeKindName::toName(kind); + + // non-null values + for (const auto& value : testValues.nonNullValues) { + auto variantExpr = std::make_shared(type, value); + auto vectorExpr = createVectorExpr( + BaseVector::createConstant(type, value, 1, pool_.get())); + EXPECT_EQ(variantExpr->hash(), vectorExpr->hash()) + << "Hash mismatch for non-null " << TypeKindName::toName(kind) + << " with value " << value.toJson(type); + } + } +} + +TEST_F(ConstantTypedExprTest, hashComplexTypes) { + // ARRAY + auto arrayType = ARRAY(INTEGER()); + + // null values + auto nullArrayVariantExpr = createNullVariantExpr(arrayType); + auto nullArrayVectorExpr = + createVectorExpr(createNullConstantVector(arrayType)); + EXPECT_EQ(nullArrayVariantExpr->hash(), nullArrayVectorExpr->hash()) + << "Hash mismatch for null ARRAY variant vs vector"; + + // non-null values + auto arrayVariant = Variant::array({1, 2, 3}); + auto arrayVariantExpr = + std::make_shared(arrayType, arrayVariant); + auto arrayVector = makeArrayVector({{1, 2, 3}}); + auto arrayVectorExpr = createVectorExpr(arrayVector); + EXPECT_EQ(arrayVariantExpr->hash(), arrayVectorExpr->hash()) + << "Hash mismatch for non-null ARRAY variant vs vector"; + + // MAP + auto mapType = MAP(VARCHAR(), INTEGER()); + + // null values + auto nullMapVariantExpr = createNullVariantExpr(mapType); + auto nullMapVectorExpr = createVectorExpr(createNullConstantVector(mapType)); + EXPECT_EQ(nullMapVariantExpr->hash(), nullMapVectorExpr->hash()) + << "Hash mismatch for null MAP variant vs vector"; + + // non-null values + std::map mapData = {{"key1", 1}, {"key2", 2}}; + auto mapVariant = 
Variant::map(mapData); + auto mapVariantExpr = + std::make_shared(mapType, mapVariant); + auto mapVector = + makeMapVector({{{"key1", 1}, {"key2", 2}}}); + auto mapVectorExpr = createVectorExpr(mapVector); + EXPECT_EQ(mapVariantExpr->hash(), mapVectorExpr->hash()) + << "Hash mismatch for non-null MAP variant vs vector"; + + // ROW + auto rowType = ROW({{"a", INTEGER()}, {"b", VARCHAR()}}); + + // null values + auto nullRowVariantExpr = createNullVariantExpr(rowType); + auto nullRowVectorExpr = createVectorExpr(createNullConstantVector(rowType)); + EXPECT_EQ(nullRowVariantExpr->hash(), nullRowVectorExpr->hash()) + << "Hash mismatch for null ROW variant vs vector"; + + // non-null values + auto rowVariant = Variant::row({42, "hello"}); + auto rowVariantExpr = + std::make_shared(rowType, rowVariant); + auto rowVector = makeRowVector( + {makeFlatVector({42}), makeFlatVector({"hello"})}); + auto rowVectorExpr = createVectorExpr(rowVector); + EXPECT_EQ(rowVariantExpr->hash(), rowVectorExpr->hash()) + << "Hash mismatch for non-null ROW variant vs vector"; + + // OPAQUE + auto testObj = std::make_shared(42, "test_data"); + auto opaqueType = OPAQUE(); + + // null values + auto nullOpaqueVariantExpr = createNullVariantExpr(opaqueType); + auto nullOpaqueVectorExpr = + createVectorExpr(createNullConstantVector(opaqueType)); + EXPECT_EQ(nullOpaqueVariantExpr->hash(), nullOpaqueVectorExpr->hash()) + << "Hash mismatch for null OPAQUE"; + + // non-null values + auto opaqueVariant = Variant::opaque(testObj); + auto opaqueVariantExpr = + std::make_shared(opaqueType, opaqueVariant); + auto opaqueVectorExpr = createVectorExpr( + BaseVector::createConstant(opaqueType, opaqueVariant, 1, pool_.get())); + EXPECT_EQ(opaqueVariantExpr->hash(), opaqueVectorExpr->hash()) + << "Hash mismatch for non-null OPAQUE"; +} + +TEST_F(ConstantTypedExprTest, serdeScalarTypes) { + // Test serialize/deserialize APIs for scalar types to ensure backward + // compatibility. 
+ for (auto kind : scalarTypes_) { + auto type = createScalarType(kind); + auto testValues = getTestValues(kind); + + // null values + auto nullVariantExpr = createNullVariantExpr(type); + auto serialized = nullVariantExpr->serialize(); + auto deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null variant " + << TypeKindName::toName(kind); + auto nullVectorExpr = createVectorExpr(createNullConstantVector(type)); + serialized = nullVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for null vector " + << TypeKindName::toName(kind); + + // non-null values + for (const auto& value : testValues.nonNullValues) { + auto variantExpr = std::make_shared(type, value); + serialized = variantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*variantExpr == *deserialized) + << "Serialize/deserialize mismatch for variant " + << TypeKindName::toName(kind); + + auto vectorExpr = createVectorExpr( + BaseVector::createConstant(type, value, 1, pool_.get())); + serialized = vectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*vectorExpr == *deserialized) + << "Serialize/deserialize mismatch for vector " + << TypeKindName::toName(kind) << " with value " << value.toJson(type); + } + } +} + +TEST_F(ConstantTypedExprTest, serdeComplexTypes) { + // ARRAY + auto arrayType = ARRAY(INTEGER()); + + // null values + auto nullArrayVariantExpr = createNullVariantExpr(arrayType); + auto serialized = nullArrayVariantExpr->serialize(); + auto deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullArrayVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null ARRAY variant"; + auto nullArrayVectorExpr = + 
createVectorExpr(createNullConstantVector(arrayType)); + serialized = nullArrayVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullArrayVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for null ARRAY vector"; + + // non-null values + auto arrayVariant = Variant::array({1, 2, 3}); + auto arrayVariantExpr = + std::make_shared(arrayType, arrayVariant); + serialized = arrayVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*arrayVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for ARRAY variant with data"; + auto arrayVector = makeArrayVector({{1, 2, 3}}); + auto arrayVectorExpr = createVectorExpr(arrayVector); + serialized = arrayVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*arrayVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for ARRAY vector with data"; + + // MAP + auto mapType = MAP(VARCHAR(), INTEGER()); + // null values + auto nullMapVariantExpr = createNullVariantExpr(mapType); + serialized = nullMapVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullMapVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null MAP variant"; + + // non-null values + std::map mapData = {{"key1", 1}, {"key2", 2}}; + auto mapVariant = Variant::map(mapData); + auto mapVariantExpr = + std::make_shared(mapType, mapVariant); + serialized = mapVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*mapVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for MAP variant with data"; + + // ROW + auto rowType = ROW({{"a", INTEGER()}, {"b", VARCHAR()}}); + // null values + auto nullRowVariantExpr = createNullVariantExpr(rowType); + serialized = nullRowVariantExpr->serialize(); + deserialized = 
ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullRowVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null ROW variant"; + + // non-null values + auto rowVariant = Variant::row({42, "hello"}); + auto rowVariantExpr = + std::make_shared(rowType, rowVariant); + serialized = rowVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*rowVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for ROW variant with data"; + + // OPAQUE + auto opaqueType = OPAQUE(); + + // null values + auto nullOpaqueVariantExpr = createNullVariantExpr(opaqueType); + serialized = nullOpaqueVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullOpaqueVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null OPAQUE variant"; + auto nullOpaqueVectorExpr = + createVectorExpr(createNullConstantVector(opaqueType)); + serialized = nullOpaqueVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullOpaqueVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for null OPAQUE vector"; + + // non-null values + auto testObj = std::make_shared(42, "test_data"); + auto opaqueVariant = Variant::opaque(testObj); + auto opaqueVariantExpr = + std::make_shared(opaqueType, opaqueVariant); + serialized = opaqueVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + auto actualObj = static_pointer_cast(deserialized) + ->value() + .value() + .obj; + EXPECT_EQ(*testObj, *static_pointer_cast(actualObj)); +} + +TEST_F(ConstantTypedExprTest, toStringScalarTypes) { + for (auto kind : scalarTypes_) { + auto type = createScalarType(kind); + auto testValues = getTestValues(kind); + + // null values + auto nullVariantExpr = createNullVariantExpr(type); + auto nullVectorExpr = createVectorExpr(createNullConstantVector(type)); + 
EXPECT_EQ(nullVariantExpr->toString(), nullVectorExpr->toString()) + << "toString mismatch for null " << TypeKindName::toName(kind); + + // non-null values + for (const auto& value : testValues.nonNullValues) { + auto variantExpr = std::make_shared(type, value); + auto vectorExpr = createVectorExpr( + BaseVector::createConstant(type, value, 1, pool_.get())); + EXPECT_EQ(variantExpr->toString(), vectorExpr->toString()) + << "toString mismatch for " << TypeKindName::toName(kind) + << " with value " << value.toJson(type); + } + } +} + +TEST_F(ConstantTypedExprTest, toStringComplexTypes) { + // ARRAY + auto arrayType = ARRAY(INTEGER()); + + // null values + auto nullArrayVariantExpr = createNullVariantExpr(arrayType); + auto nullArrayVectorExpr = + createVectorExpr(createNullConstantVector(arrayType)); + EXPECT_EQ(nullArrayVariantExpr->toString(), nullArrayVectorExpr->toString()) + << "toString mismatch for null ARRAY"; + + // non-null values + auto arrayVariant = Variant::array({1, 2, 3}); + auto arrayVariantExpr = + std::make_shared(arrayType, arrayVariant); + auto arrayVector = makeArrayVector({{1, 2, 3}}); + auto arrayVectorExpr = createVectorExpr(arrayVector); + EXPECT_EQ(arrayVariantExpr->toString(), arrayVectorExpr->toString()) + << "toString mismatch for ARRAY variant vs vector"; + + // MAP + auto mapType = MAP(VARCHAR(), INTEGER()); + + // null values + auto nullMapVariantExpr = createNullVariantExpr(mapType); + auto nullMapVectorExpr = createVectorExpr(createNullConstantVector(mapType)); + EXPECT_EQ(nullMapVariantExpr->toString(), nullMapVectorExpr->toString()) + << "toString mismatch for null MAP"; + + // non-null values + std::map mapData = {{"key1", 1}, {"key2", 2}}; + auto mapVariant = Variant::map(mapData); + auto mapVariantExpr = + std::make_shared(mapType, mapVariant); + auto mapVector = + makeMapVector({{{"key1", 1}, {"key2", 2}}}); + auto mapVectorExpr = createVectorExpr(mapVector); + EXPECT_EQ(mapVariantExpr->toString(), mapVectorExpr->toString()) + 
<< "toString mismatch for MAP variant vs vector"; + + // ROW + auto rowType = ROW({{"a", INTEGER()}, {"b", VARCHAR()}}); + + // null values + auto nullRowVariantExpr = createNullVariantExpr(rowType); + auto nullRowVectorExpr = createVectorExpr(createNullConstantVector(rowType)); + EXPECT_EQ(nullRowVariantExpr->toString(), nullRowVectorExpr->toString()) + << "toString mismatch for null ROW"; + + // non-null values + auto rowVariant = Variant::row({42, "hello"}); + auto rowVariantExpr = + std::make_shared(rowType, rowVariant); + auto rowVector = makeRowVector( + {makeFlatVector({42}), makeFlatVector({"hello"})}); + auto rowVectorExpr = createVectorExpr(rowVector); + EXPECT_EQ(rowVariantExpr->toString(), rowVectorExpr->toString()) + << "toString mismatch for ROW variant vs vector"; + + // OPAQUE + auto opaqueType = OPAQUE(); + + // null values + auto nullOpaqueVariantExpr = createNullVariantExpr(opaqueType); + auto nullOpaqueVectorExpr = + createVectorExpr(createNullConstantVector(opaqueType)); + EXPECT_EQ(nullOpaqueVariantExpr->toString(), nullOpaqueVectorExpr->toString()) + << "toString mismatch for null OPAQUE"; + + // non-null values + auto testObj = std::make_shared(42, "test_data"); + auto opaqueVariant = Variant::opaque(testObj); + auto opaqueVariantExpr = + std::make_shared(opaqueType, opaqueVariant); + auto opaqueVectorExpr = createVectorExpr( + BaseVector::createConstant(opaqueType, opaqueVariant, 1, pool_.get())); + EXPECT_EQ(opaqueVariantExpr->toString(), opaqueVectorExpr->toString()) + << "toString mismatch for OPAQUE variant vs vector"; +} + } // namespace facebook::velox::core::test diff --git a/velox/core/tests/PlanConsistencyCheckerTest.cpp b/velox/core/tests/PlanConsistencyCheckerTest.cpp index 41c7992b7ddb..9aa0f7f3ff0f 100644 --- a/velox/core/tests/PlanConsistencyCheckerTest.cpp +++ b/velox/core/tests/PlanConsistencyCheckerTest.cpp @@ -33,7 +33,7 @@ TypedExprPtr Lit(Variant value) { return std::make_shared(std::move(type), std::move(value)); } 
-TypedExprPtr Col(TypePtr type, std::string name) { +FieldAccessTypedExprPtr Col(TypePtr type, std::string name) { return std::make_shared( std::move(type), std::move(name)); } @@ -100,5 +100,115 @@ TEST_F(PlanConsistencyCheckerTest, project) { PlanConsistencyChecker::check(projectNode), "Field not found: x"); } +TEST_F(PlanConsistencyCheckerTest, aggregation) { + auto valuesNode = + std::make_shared("0", std::vector{}); + + auto projectNode = std::make_shared( + "1", + std::vector{"a", "b", "c"}, + std::vector{Lit(true), Lit(1), Lit(0.1)}, + valuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(projectNode)); + + { + auto aggregationNode = std::make_shared( + "2", + AggregationNode::Step::kPartial, + std::vector{}, + std::vector{}, + std::vector{"sum", "cnt"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "x")), + .rawInputTypes = {BIGINT()}, + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys*/ false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), "Field not found: x"); + } + + { + auto aggregationNode = std::make_shared( + "2", + AggregationNode::Step::kPartial, + std::vector{Col(INTEGER(), "y")}, + std::vector{}, + std::vector{"sum", "cnt"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "b")), + .rawInputTypes = {BIGINT()}, + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys*/ false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), "Field not found: y"); + } + + { + auto aggregationNode = std::make_shared( + "2", + AggregationNode::Step::kPartial, + std::vector{}, + std::vector{}, + std::vector{"sum", "cnt"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "b")), + .rawInputTypes = {BIGINT()}, + .mask = Col(BOOLEAN(), "z"), + }, + { + .call = 
std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys*/ false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), "Field not found: z"); + } + + { + auto aggregationNode = std::make_shared( + "2", + AggregationNode::Step::kPartial, + std::vector{}, + std::vector{}, + std::vector{"sum", "sum"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "b")), + .rawInputTypes = {BIGINT()}, + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys*/ false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), + "Duplicate output column: sum"); + } +} + } // namespace } // namespace facebook::velox::core diff --git a/velox/core/tests/PlanNodeBuilderTest.cpp b/velox/core/tests/PlanNodeBuilderTest.cpp index 840baacb7129..6e624f722628 100644 --- a/velox/core/tests/PlanNodeBuilderTest.cpp +++ b/velox/core/tests/PlanNodeBuilderTest.cpp @@ -882,24 +882,34 @@ TEST_F(PlanNodeBuilderTest, spatialJoinNode) { const auto joinType = JoinType::kInner; const auto joinCondition = std::make_shared(BOOLEAN(), variant(true)); - const auto left = - ValuesNode::Builder() - .id("values_node_id_1") - .values({makeRowVector( - {"c0"}, {makeFlatVector(std::vector{1})})}) - .build(); - const auto right = - ValuesNode::Builder() - .id("values_node_id_2") - .values({makeRowVector( - {"c1"}, {makeFlatVector(std::vector{2})})}) - .build(); + const auto left = ValuesNode::Builder() + .id("values_node_id_1") + .values({makeRowVector( + {"c0", "g0"}, + {makeFlatVector(std::vector{1}), + makeFlatVector( + std::vector{"POINT(0 0)"})})}) + .build(); + const auto right = ValuesNode::Builder() + .id("values_node_id_2") + .values({makeRowVector( + {"c1", "g1"}, + {makeFlatVector(std::vector{2}), + makeFlatVector( + std::vector{"POINT(0 0)"})})}) + .build(); const auto outputType = ROW({"c0"}, {BIGINT()}); + const auto probeGeom = + 
std::make_shared(VARCHAR(), "g0"); + const auto buildGeom = + std::make_shared(VARCHAR(), "g1"); const auto verify = [&](const std::shared_ptr& node) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->joinType(), joinType); EXPECT_EQ(node->joinCondition(), joinCondition); + EXPECT_EQ(node->probeGeometry(), probeGeom); + EXPECT_EQ(node->buildGeometry(), buildGeom); EXPECT_EQ(node->sources()[0], left); EXPECT_EQ(node->sources()[1], right); EXPECT_EQ(node->outputType(), outputType); @@ -911,6 +921,8 @@ TEST_F(PlanNodeBuilderTest, spatialJoinNode) { .joinCondition(joinCondition) .left(left) .right(right) + .probeGeometry(probeGeom) + .buildGeometry(buildGeom) .outputType(outputType) .build(); verify(node); diff --git a/velox/core/tests/PlanNodeTest.cpp b/velox/core/tests/PlanNodeTest.cpp index c09d3b44cd76..2c4a0f90f6a3 100644 --- a/velox/core/tests/PlanNodeTest.cpp +++ b/velox/core/tests/PlanNodeTest.cpp @@ -85,6 +85,55 @@ TEST_F(PlanNodeTest, findFirstNode) { })); } +TEST_F(PlanNodeTest, findNodeById) { + auto values = std::make_shared("1", std::vector{}); + auto project = std::make_shared( + "2", + std::vector{"a", "b"}, + std::vector{ + std::make_shared(DOUBLE(), "rand"), + std::make_shared(DOUBLE(), "rand"), + }, + values); + + auto filter = std::make_shared( + "3", + std::make_shared( + BOOLEAN(), + "gt", + std::make_shared(DOUBLE(), "a"), + std::make_shared(DOUBLE(), 0.5)), + project); + + auto limit = std::make_shared("4", 0, 10, false, filter); + + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "1"), values.get()); + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "2"), project.get()); + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "3"), filter.get()); + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "4"), limit.get()); + + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "5"), nullptr); + ASSERT_EQ(PlanNode::findNodeById(project.get(), "4"), nullptr); +} + +TEST_F(PlanNodeTest, is) { + auto values = std::make_shared("1", std::vector{}); + auto project = std::make_shared( 
+ "2", + std::vector{"a", "b"}, + std::vector{ + std::make_shared(DOUBLE(), "rand"), + std::make_shared(DOUBLE(), "rand"), + }, + values); + + ASSERT_TRUE(values->is()); + ASSERT_FALSE(values->is()); + + ASSERT_FALSE(project->is()); + ASSERT_TRUE(project->is()); +} + TEST_F(PlanNodeTest, sortOrder) { struct { SortOrder order1; @@ -132,6 +181,7 @@ TEST_F(PlanNodeTest, duplicateSortKeys) { "orderBy", sortingKeys, sortingOrders, false, nullptr), "Duplicate sorting keys are not allowed: c0"); } + class TestIndexTableHandle : public connector::ConnectorTableHandle { public: TestIndexTableHandle() @@ -163,7 +213,7 @@ class TestIndexTableHandle : public connector::ConnectorTableHandle { } }; -TEST_F(PlanNodeTest, isIndexLookupJoin) { +TEST_F(PlanNodeTest, indexLookupJoin) { const auto rowType = ROW({"name"}, {BIGINT()}); const auto valueNode = std::make_shared("orderBy", rowData_); ASSERT_FALSE(isIndexLookupJoin(valueNode.get())); @@ -193,12 +243,17 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { leftKeys, rightKeys, std::vector{}, - /*includeMatchColumn=*/false, + /*filter=*/nullptr, + /*hasMarker=*/false, probeNode, buildNode, outputType); ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithInnerJoin.get())); - ASSERT_FALSE(indexJoinNodeWithInnerJoin->includeMatchColumn()); + ASSERT_FALSE(indexJoinNodeWithInnerJoin->hasMarker()); + ASSERT_EQ(indexJoinNodeWithInnerJoin->filter(), nullptr); + ASSERT_EQ( + indexJoinNodeWithInnerJoin->toString(/*detailed=*/true), + "-- IndexLookupJoin[indexJoinNode][INNER c0=c1] -> c0:BIGINT, c1:BIGINT\n"); } { const RowTypePtr outputTypeWithMatchColumn = @@ -210,12 +265,39 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { leftKeys, rightKeys, std::vector{}, - /*includeMatchColumn=*/true, + /*filter=*/nullptr, + /*hasMarker=*/true, probeNode, buildNode, outputTypeWithMatchColumn); ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithLeftJoin.get())); - ASSERT_TRUE(indexJoinNodeWithLeftJoin->includeMatchColumn()); + 
ASSERT_TRUE(indexJoinNodeWithLeftJoin->hasMarker()); + ASSERT_EQ(indexJoinNodeWithLeftJoin->filter(), nullptr); + ASSERT_EQ( + indexJoinNodeWithLeftJoin->toString(/*detailed=*/true), + "-- IndexLookupJoin[indexJoinNode][LEFT c0=c1] -> c0:BIGINT, c1:BIGINT, c2:BOOLEAN\n"); + } + { + // Test IndexLookupJoinNode with filter + const auto filterExpr = std::make_shared( + BOOLEAN(), "filter_column"); + const auto indexJoinNodeWithFilter = std::make_shared( + "indexJoinNodeWithFilter", + core::JoinType::kInner, + leftKeys, + rightKeys, + std::vector{}, + /*filter=*/filterExpr, + /*hasMarker=*/false, + probeNode, + buildNode, + outputType); + ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithFilter.get())); + ASSERT_FALSE(indexJoinNodeWithFilter->hasMarker()); + ASSERT_EQ(indexJoinNodeWithFilter->filter(), filterExpr); + ASSERT_EQ( + indexJoinNodeWithFilter->toString(/*detailed=*/true), + "-- IndexLookupJoin[indexJoinNodeWithFilter][INNER c0=c1, filter: \"filter_column\"] -> c0:BIGINT, c1:BIGINT\n"); } // Error case. 
{ @@ -226,7 +308,8 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { leftKeys, rightKeys, std::vector{}, - /*includeMatchColumn=*/true, + /*filter=*/nullptr, + /*hasMarker=*/true, probeNode, buildNode, outputType), @@ -240,7 +323,8 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { leftKeys, rightKeys, std::vector{}, - /*includeMatchColumn=*/true, + /*filter=*/nullptr, + /*hasMarker=*/true, probeNode, buildNode, outputType), @@ -256,7 +340,8 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { leftKeys, rightKeys, std::vector{}, - /*includeMatchColumn=*/true, + /*filter=*/nullptr, + /*hasMarker=*/true, probeNode, buildNode, outputTypeWithDuplicateMatchColumn), diff --git a/velox/core/tests/QueryConfigTest.cpp b/velox/core/tests/QueryConfigTest.cpp index daab3fad2a06..ba04e90007ff 100644 --- a/velox/core/tests/QueryConfigTest.cpp +++ b/velox/core/tests/QueryConfigTest.cpp @@ -205,4 +205,50 @@ TEST_F(QueryConfigTest, expressionEvaluationRelatedConfigs) { testConfig(createConfig(false, false, false, true)); } +TEST_F(QueryConfigTest, sessionStartTime) { + // Test with no session start time set + { + auto queryCtx = QueryCtx::create(nullptr, QueryConfig{{}}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), 0); + } + + // Test with session start time set + { + int64_t startTimeMs = 1674123456789; // Some timestamp in milliseconds + std::unordered_map configData( + {{QueryConfig::kSessionStartTime, std::to_string(startTimeMs)}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), startTimeMs); + } + + // Test with negative session start time (should be valid) + { + int64_t negativeStartTime = -1000; + std::unordered_map configData( + {{QueryConfig::kSessionStartTime, std::to_string(negativeStartTime)}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config 
= queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), negativeStartTime); + } + + // Test with maximum int64_t value + { + int64_t maxTime = std::numeric_limits::max(); + std::unordered_map configData( + {{QueryConfig::kSessionStartTime, std::to_string(maxTime)}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), maxTime); + } +} + } // namespace facebook::velox::core::test diff --git a/velox/docs/conf.py b/velox/docs/conf.py index e6401fa5fd2f..a87507e85dae 100644 --- a/velox/docs/conf.py +++ b/velox/docs/conf.py @@ -51,6 +51,7 @@ "pr", "spark", "iceberg", + "delta", "sphinx.ext.autodoc", "sphinx.ext.doctest", "sphinx.ext.mathjax", diff --git a/velox/docs/configs.rst b/velox/docs/configs.rst index 21eebecf8dfb..b7669ee5f197 100644 --- a/velox/docs/configs.rst +++ b/velox/docs/configs.rst @@ -64,7 +64,7 @@ Generic Configuration - bool - true - If true, the driver will collect the operator's input/output batch size through vector flat size estimation, otherwise not. - - We might turn this off in use cases which have very wide column width and batch size estimation has non-trivial cpu cost. + We might turn this off in use cases which have very wide column width and batch size estimation has non-trivial cpu cost. * - hash_adaptivity_enabled - bool - true @@ -200,7 +200,6 @@ Generic Configuration be expensive (especially if operator stats are retrieved frequently) and this allows the user to explicitly enable it. -.. _expression-evaluation-conf: Expression Evaluation Configuration ----------------------------------- @@ -335,6 +334,12 @@ Spilling - boolean - true - When `spill_enabled` is true, determines whether Window operator can spill to disk under memory pressure. 
+ * - window_spill_min_read_batch_rows + - integer + - 1000 + - When processing spilled window data, read batches of whole partitions having at least that many rows. Set to 1 to + read one whole partition at a time. Each driver processing the Window operator will process that much data at + once. * - row_number_spill_enabled - boolean - true @@ -548,6 +553,31 @@ Table Writer - Minimum amount of data processed by all the logical table partitions to trigger skewed partition rebalancing by scale writer exchange. +Connector Config +---------------- +Connector config is initialized on velox runtime startup and is shared among queries as the default config across all connectors. +Each query can override the config by setting corresponding query session properties such as in Prestissimo. + +.. list-table:: + :widths: 20 20 10 10 70 + :header-rows: 1 + + * - user + - + - string + - "" + - The user of the query. Used for storage logging. + * - source + - + - string + - "" + - The source of the query. Used for storage access and logging. + * - schema + - + - string + - "" + - The schema of the query. Used for storage logging. + Hive Connector -------------- Hive Connector config is initialized on velox runtime startup and is shared among queries as the default config. @@ -686,7 +716,6 @@ Each query can override the config by setting corresponding query session proper - false - Whether to preserve flat maps in memory as FlatMapVectors instead of converting them to MapVectors. This is only applied during data reading inside the DWRF and Nimble readers, not during downstream processing like expression evaluation etc. - ``ORC File Format Configuration`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. list-table:: @@ -970,6 +999,12 @@ These semantics are similar to the `Apache Hadoop-Aws module /oauth2/token`. + * - fs.azure.sas.token.renew.period.for.streams + - string + - 120 + - Specifies the period in seconds to re-use SAS tokens until the expiry is within this number of seconds. 
+ This configuration is used together with `registerSasTokenProvider` for dynamic SAS token renewal. + When a SAS token is close to expiry, it will be renewed by getting a new token from the provider. Presto-specific Configuration ----------------------------- diff --git a/velox/docs/develop/connectors.rst b/velox/docs/develop/connectors.rst index 9e550e007e7d..b4c4275cdb85 100644 --- a/velox/docs/develop/connectors.rst +++ b/velox/docs/develop/connectors.rst @@ -89,7 +89,7 @@ S3 is supported using the `AWS SDK for C++ ` S3 supported schemes are `s3://` (Amazon S3, Minio), `s3a://` (Hadoop 3.x), `s3n://` (Deprecated in Hadoop 3.x), `oss://` (Alibaba cloud storage), and `cos://`, `cosn://` (Tencent cloud storage). -HDFS is supported using the +HDFS is supported using the `Apache Hadoop libhdfs.so `_ and `Apache Hawk libhdfs3 `_ library. HDFS supported schemes are `hdfs://`. @@ -121,3 +121,8 @@ This is the behavior when the proxy settings are enabled: 4. The no_proxy/NO_PROXY list is comma separated. 5. Use . or \*. to indicate domain suffix matching, e.g. `.foobar.com` will match `test.foobar.com` or `foo.foobar.com`. + +HDFS Storage adapter +******************** + +Velox currently supports HDFS by dynamically loading libhdfs.so from the environment's ${HADOOP_HOME}/native/lib directory. If you prefer to use libhdfs3 instead, you can create a symbolic link from libhdfs.so to libhdfs3.so within the same directory. 
diff --git a/velox/docs/develop/types.rst b/velox/docs/develop/types.rst index 98106d931bb7..0dd4be596a8c 100644 --- a/velox/docs/develop/types.rst +++ b/velox/docs/develop/types.rst @@ -115,6 +115,7 @@ DATE INTEGER DECIMAL BIGINT if precision <= 18, HUGEINT if precision >= 19 INTERVAL DAY TO SECOND BIGINT INTERVAL YEAR TO MONTH INTEGER +TIME BIGINT ====================== ====================================================== DECIMAL type carries additional `precision`, @@ -130,6 +131,9 @@ upto 38 precision, with a range of :math:`[-10^{38} + 1, +10^{38} - 1]`. All the three values, precision, scale, unscaled value are required to represent a decimal value. +TIME type represents time in milliseconds from midnight UTC. Thus min/max value can range from UTC-14:00 at 00:00:00 to UTC+14:00 at 23:59:59.999 modulo 24 hours. +TIME type is backed by BIGINT physical type. + Custom Types ~~~~~~~~~~~~ Most custom types can be represented as logical types and can be built by extending @@ -178,6 +182,7 @@ TDIGEST VARBINARY QDIGEST VARBINARY BIGINT_ENUM BIGINT VARCHAR_ENUM VARCHAR +TIME WITH TIME ZONE BIGINT ======================== ===================== TIMESTAMP WITH TIME ZONE represents a time point in milliseconds precision @@ -243,6 +248,12 @@ VarcharEnumParameter as the key. Casting is only permitted to and from VARCHAR type, and is case-sensitive. Casting between different enum types is not permitted. Comparison operations are only allowed between values of the same enum type. +TIME WITH TIME ZONE represents time from midnight in milliseconds precision at a particular timezone. +Its physical type is BIGINT. The high 52 bits of bigint store signed integer for milliseconds in UTC. +The lower 12 bits store the time zone offsets minutes. This allows the time to be converted at any point of +time without ambiguity of daylight savings time. Time zone offsets range from -14:00 hours to +14:00 hours. 
+ + Spark Types ~~~~~~~~~~~~ The `data types `_ in Spark have some semantic differences compared to those in diff --git a/velox/docs/ext/delta.py b/velox/docs/ext/delta.py new file mode 100644 index 000000000000..40a76426a2e0 --- /dev/null +++ b/velox/docs/ext/delta.py @@ -0,0 +1,773 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generating delta function link :delta:func:``""" + +from __future__ import annotations + +from function import ( + function_sig_re, + pairindextypes, + parse_arglist, + pseudo_parse_arglist, + parse_annotation, + ObjectEntry, + ModuleEntry, +) +from typing import Any, Iterable, Iterator, Tuple, cast + +from docutils import nodes +from docutils.nodes import Element, Node +from docutils.parsers.rst import directives +from sphinx import addnodes +from sphinx.addnodes import desc_signature, pending_xref +from sphinx.application import Sphinx +from sphinx.builders import Builder +from sphinx.directives import ObjectDescription +from sphinx.domains import Domain, Index, IndexEntry, ObjType +from sphinx.environment import BuildEnvironment +from sphinx.locale import _, __ +from sphinx.roles import XRefRole +from sphinx.util import logging +from sphinx.util.docfields import Field +from sphinx.util.nodes import ( + find_pending_xref_condition, + make_id, + make_refnode, +) +from sphinx.util.typing import OptionSpec + +logger = logging.getLogger(__name__) + +function_module = "delta" + + 
+class DeltaObject(ObjectDescription[Tuple[str, str]]): + """ + Description of a general Delta object. + + :cvar allow_nesting: Class is an object that allows for nested namespaces + :vartype allow_nesting: bool + """ + + option_spec: OptionSpec = { + "noindex": directives.flag, + "noindexentry": directives.flag, + "nocontentsentry": directives.flag, + "module": directives.unchanged, + "canonical": directives.unchanged, + "annotation": directives.unchanged, + } + + doc_field_types = [ + Field( + "returnvalue", + label=_("Returns"), + has_arg=False, + names=("returns", "return"), + ), + ] + + allow_nesting = False + + def get_signature_prefix(self, sig: str) -> list[nodes.Node]: + """May return a prefix to put before the object name in the + signature. + """ + return [] + + def needs_arglist(self) -> bool: + """May return true if an empty argument list is to be generated even if + the document contains none. + """ + return False + + def handle_signature(self, sig: str, signode: desc_signature) -> tuple[str, str]: + """Transform a Delta signature into RST nodes. + Return (fully qualified name of the thing, classname if any). 
+ If inside a class, the current class name is handled intelligently: + * it is stripped from the displayed name if present + * it is added to the full name (return value) if not present + """ + m = function_sig_re.match(sig) + if m is None: + raise ValueError + prefix, name, arglist, retann = m.groups() + + # determine module and class name (if applicable), as well as full name + modname = self.options.get("module", self.env.ref_context.get("delta:module")) + classname = self.env.ref_context.get("delta:class") + if classname: + add_module = False + if prefix and (prefix == classname or prefix.startswith(classname + ".")): + fullname = prefix + name + # class name is given again in the signature + prefix = prefix[len(classname) :].lstrip(".") + elif prefix: + # class name is given in the signature, but different + # (shouldn't happen) + fullname = classname + "." + prefix + name + else: + # class name is not given in the signature + fullname = classname + "." + name + else: + add_module = True + if prefix: + classname = prefix.rstrip(".") + fullname = prefix + name + else: + classname = "" + fullname = name + + signode["module"] = modname + signode["class"] = classname + signode["fullname"] = fullname + + sig_prefix = self.get_signature_prefix(sig) + if sig_prefix: + if type(sig_prefix) is str: + raise TypeError( + "Python directive method get_signature_prefix()" + " must return a list of nodes." + f" Return value was '{sig_prefix}'." + ) + else: + signode += addnodes.desc_annotation(str(sig_prefix), "", *sig_prefix) + + if prefix: + signode += addnodes.desc_addname(prefix, prefix) + elif modname and add_module and self.env.config.add_module_names: + nodetext = modname + "." + signode += addnodes.desc_addname(nodetext, nodetext) + + signode += addnodes.desc_name(name, name) + if arglist: + try: + signode += parse_arglist(function_module, arglist, self.env) + except SyntaxError: + # fallback to parse arglist original parser. 
+ # it supports to represent optional arguments (ex. "func(foo [, bar])") + pseudo_parse_arglist(signode, arglist) + except NotImplementedError as exc: + logger.warning( + "could not parse arglist (%r): %s", arglist, exc, location=signode + ) + pseudo_parse_arglist(signode, arglist) + else: + if self.needs_arglist(): + # for callables, add an empty parameter list + signode += addnodes.desc_parameterlist() + + if retann: + children = parse_annotation(function_module, retann, self.env) + signode += addnodes.desc_returns(retann, "", *children) + + anno = self.options.get("annotation") + if anno: + signode += addnodes.desc_annotation( + " " + anno, "", addnodes.desc_sig_space(), nodes.Text(anno) + ) + + return fullname, prefix + + def _object_hierarchy_parts(self, sig_node: desc_signature) -> tuple[str, ...]: + if "fullname" not in sig_node: + return () + modname = sig_node.get("module") + fullname = sig_node["fullname"] + + if modname: + return (modname, *fullname.split(".")) + else: + return tuple(fullname.split(".")) + + def get_index_text(self, modname: str, name: tuple[str, str]) -> str: + """Return the text for the index entry of the object.""" + raise NotImplementedError("must be implemented in subclasses") + + def add_target_and_index( + self, name_cls: tuple[str, str], sig: str, signode: desc_signature + ) -> None: + modname = self.options.get("module", self.env.ref_context.get("delta:module")) + fullname = (modname + "." 
if modname else "") + name_cls[0] + node_id = make_id(self.env, self.state.document, "", fullname) + signode["ids"].append(node_id) + self.state.document.note_explicit_target(signode) + + domain = cast(DeltaDomain, self.env.get_domain("delta")) + domain.note_object(fullname, self.objtype, node_id, location=signode) + + canonical_name = self.options.get("canonical") + if canonical_name: + domain.note_object( + canonical_name, self.objtype, node_id, aliased=True, location=signode + ) + + if "noindexentry" not in self.options: + indextext = self.get_index_text(modname, name_cls) + if indextext: + self.indexnode["entries"].append( + ("single", indextext, node_id, "", None) + ) + + def before_content(self) -> None: + """Handle object nesting before content + + For constructs that aren't nestable, the stack is bypassed, and instead + only the most recent object is tracked. This object prefix name will be + removed with :delta:meth:`after_content`. + """ + prefix = None + if self.names: + # fullname and name_prefix come from the `handle_signature` method. + # fullname represents the full object name that is constructed using + # object nesting and explicit prefixes. `name_prefix` is the + # explicit prefix given in a signature + (fullname, name_prefix) = self.names[-1] + if self.allow_nesting: + prefix = fullname + elif name_prefix: + prefix = name_prefix.strip(".") + if prefix: + self.env.ref_context["delta:class"] = prefix + if self.allow_nesting: + classes = self.env.ref_context.setdefault("delta:classes", []) + classes.append(prefix) + if "module" in self.options: + modules = self.env.ref_context.setdefault("delta:modules", []) + modules.append(self.env.ref_context.get("delta:module")) + self.env.ref_context["delta:module"] = self.options["module"] + + def after_content(self) -> None: + """Handle object de-nesting after content + + If this class is a nestable object, removing the last nested class prefix + ends further nesting in the object. 
+ + If this class is not a nestable object, the list of classes should not + be altered as we didn't affect the nesting levels in + :delta:meth:`before_content`. + """ + classes = self.env.ref_context.setdefault("delta:classes", []) + if self.allow_nesting: + try: + classes.pop() + except IndexError: + pass + self.env.ref_context["delta:class"] = classes[-1] if len(classes) > 0 else None + if "module" in self.options: + modules = self.env.ref_context.setdefault("delta:modules", []) + if modules: + self.env.ref_context["delta:module"] = modules.pop() + else: + self.env.ref_context.pop("delta:module") + + def _toc_entry_name(self, sig_node: desc_signature) -> str: + if not sig_node.get("_toc_parts"): + return "" + + config = self.env.app.config + objtype = sig_node.parent.get("objtype") + if config.add_function_parentheses and objtype in {"function", "method"}: + parens = "()" + else: + parens = "" + *parents, name = sig_node["_toc_parts"] + if config.toc_object_entries_show_parents == "domain": + return sig_node.get("fullname", name) + parens + if config.toc_object_entries_show_parents == "hide": + return name + parens + if config.toc_object_entries_show_parents == "all": + return ".".join(parents + [name + parens]) + return "" + + +class DeltaFunction(DeltaObject): + """Description of a function.""" + + option_spec: OptionSpec = DeltaObject.option_spec.copy() + option_spec.update( + { + "async": directives.flag, + } + ) + + def get_signature_prefix(self, sig: str) -> list[nodes.Node]: + if "async" in self.options: + return [addnodes.desc_sig_keyword("", "async"), addnodes.desc_sig_space()] + else: + return [] + + def needs_arglist(self) -> bool: + return True + + def add_target_and_index( + self, name_cls: tuple[str, str], sig: str, signode: desc_signature + ) -> None: + super().add_target_and_index(name_cls, sig, signode) + if "noindexentry" not in self.options: + modname = self.options.get( + "module", self.env.ref_context.get("delta:module") + ) + node_id = 
signode["ids"][0] + + name, cls = name_cls + if modname: + text = _("%s() (in module %s)") % (name, modname) + self.indexnode["entries"].append(("single", text, node_id, "", None)) + else: + text = f"{pairindextypes['builtin']}; {name}()" + self.indexnode["entries"].append(("pair", text, node_id, "", None)) + + def get_index_text(self, modname: str, name_cls: tuple[str, str]) -> str | None: + # add index in own add_target_and_index() instead. + return None + + +class DeltaXRefRole(XRefRole): + def process_link( + self, + env: BuildEnvironment, + refnode: Element, + has_explicit_title: bool, + title: str, + target: str, + ) -> tuple[str, str]: + refnode["delta:module"] = env.ref_context.get("delta:module") + refnode["delta:class"] = env.ref_context.get("delta:class") + if not has_explicit_title: + title = title.lstrip(".") # only has a meaning for the target + target = target.lstrip("~") # only has a meaning for the title + # if the first character is a tilde, don't display the module/class + # parts of the contents + if title[0:1] == "~": + title = title[1:] + dot = title.rfind(".") + if dot != -1: + title = title[dot + 1 :] + # if the first character is a dot, search more specific namespaces first + # else search builtins first + if target[0:1] == ".": + target = target[1:] + refnode["refspecific"] = True + return title, target + + +class DeltaModuleIndex(Index): + """ + Index subclass to provide the Delta module index. 
+ """ + + name = "modindex" + localname = _("Delta Module Index") + shortname = _("modules") + + def generate( + self, docnames: Iterable[str] | None = None + ) -> tuple[list[tuple[str, list[IndexEntry]]], bool]: + content: dict[str, list[IndexEntry]] = {} + # list of prefixes to ignore + ignores: list[str] = self.domain.env.config["modindex_common_prefix"] + ignores = sorted(ignores, key=len, reverse=True) + # list of all modules, sorted by module name + modules = sorted( + self.domain.data["modules"].items(), key=lambda x: x[0].lower() + ) + # sort out collapsible modules + prev_modname = "" + num_toplevels = 0 + for modname, (docname, node_id, synopsis, platforms, deprecated) in modules: + if docnames and docname not in docnames: + continue + + for ignore in ignores: + if modname.startswith(ignore): + modname = modname[len(ignore) :] + stripped = ignore + break + else: + stripped = "" + + # we stripped the whole module name? + if not modname: + modname, stripped = stripped, "" + + entries = content.setdefault(modname[0].lower(), []) + + package = modname.split(".")[0] + if package != modname: + # it's a submodule + if prev_modname == package: + # first submodule - make parent a group head + if entries: + last = entries[-1] + entries[-1] = IndexEntry( + last[0], 1, last[2], last[3], last[4], last[5], last[6] + ) + elif not prev_modname.startswith(package): + # submodule without parent in list, add dummy entry + entries.append( + IndexEntry(stripped + package, 1, "", "", "", "", "") + ) + subtype = 2 + else: + num_toplevels += 1 + subtype = 0 + + qualifier = _("Deprecated") if deprecated else "" + entries.append( + IndexEntry( + stripped + modname, + subtype, + docname, + node_id, + platforms, + qualifier, + synopsis, + ) + ) + prev_modname = modname + + # apply heuristics when to collapse modindex at page load: + # only collapse if number of toplevel modules is larger than + # number of submodules + collapse = len(modules) - num_toplevels < num_toplevels + + # 
sort by first letter + sorted_content = sorted(content.items()) + + return sorted_content, collapse + + +class DeltaDomain(Domain): + """Delta domain.""" + + name = "delta" + label = "Delta" + object_types: dict[str, ObjType] = { + "function": ObjType(_("function"), "func", "obj"), + } + + directives = { + "function": DeltaFunction, + } + roles = { + "func": DeltaXRefRole(fix_parens=True), + } + initial_data: dict[str, dict[str, tuple[Any]]] = { + "objects": {}, # fullname -> docname, objtype + "modules": {}, # modname -> docname, synopsis, platform, deprecated + } + indices = [ + DeltaModuleIndex, + ] + + @property + def objects(self) -> dict[str, ObjectEntry]: + return self.data.setdefault("objects", {}) # fullname -> ObjectEntry + + def note_object( + self, + name: str, + objtype: str, + node_id: str, + aliased: bool = False, + location: Any = None, + ) -> None: + """Note a delta object for cross reference. + + .. versionadded:: 2.1 + """ + if name in self.objects: + other = self.objects[name] + if other.aliased and aliased is False: + # The original definition found. Override it! + pass + elif other.aliased is False and aliased: + # The original definition is already registered. + return + else: + # duplicated + logger.warning( + __( + "duplicate object description of %s, " + "other instance in %s, use :noindex: for one of them" + ), + name, + other.docname, + location=location, + ) + self.objects[name] = ObjectEntry(self.env.docname, node_id, objtype, aliased) + + @property + def modules(self) -> dict[str, ModuleEntry]: + return self.data.setdefault("modules", {}) # modname -> ModuleEntry + + def note_module( + self, name: str, node_id: str, synopsis: str, platform: str, deprecated: bool + ) -> None: + """Note a delta module for cross reference. + + .. 
versionadded:: 2.1 + """ + self.modules[name] = ModuleEntry( + self.env.docname, node_id, synopsis, platform, deprecated + ) + + def clear_doc(self, docname: str) -> None: + for fullname, obj in list(self.objects.items()): + if obj.docname == docname: + del self.objects[fullname] + for modname, mod in list(self.modules.items()): + if mod.docname == docname: + del self.modules[modname] + + def merge_domaindata(self, docnames: list[str], otherdata: dict[str, Any]) -> None: + # XXX check duplicates? + for fullname, obj in otherdata["objects"].items(): + if obj.docname in docnames: + self.objects[fullname] = obj + for modname, mod in otherdata["modules"].items(): + if mod.docname in docnames: + self.modules[modname] = mod + + def find_obj( + self, + env: BuildEnvironment, + modname: str, + classname: str, + name: str, + type: str | None, + searchmode: int = 0, + ) -> list[tuple[str, ObjectEntry]]: + """Find a Delta object for "name", perhaps using the given module + and/or classname. Returns a list of (name, object entry) tuples. + """ + # skip parens + if name[-2:] == "()": + name = name[:-2] + + if not name: + return [] + + matches: list[tuple[str, ObjectEntry]] = [] + + newname = None + if searchmode == 1: + if type is None: + objtypes = list(self.object_types) + else: + objtypes = self.objtypes_for_role(type) + if objtypes is not None: + if modname and classname: + fullname = modname + "." + classname + "." + name + if ( + fullname in self.objects + and self.objects[fullname].objtype in objtypes + ): + newname = fullname + if not newname: + if ( + modname + and modname + "." + name in self.objects + and self.objects[modname + "." + name].objtype in objtypes + ): + newname = modname + "." + name + elif ( + name in self.objects and self.objects[name].objtype in objtypes + ): + newname = name + else: + # "fuzzy" searching mode + searchname = "." 
+ name + matches = [ + (oname, self.objects[oname]) + for oname in self.objects + if oname.endswith(searchname) + and self.objects[oname].objtype in objtypes + ] + else: + # NOTE: searching for exact match, object type is not considered + if name in self.objects: + newname = name + elif type == "mod": + # only exact matches allowed for modules + return [] + elif classname and classname + "." + name in self.objects: + newname = classname + "." + name + elif modname and modname + "." + name in self.objects: + newname = modname + "." + name + elif ( + modname + and classname + and modname + "." + classname + "." + name in self.objects + ): + newname = modname + "." + classname + "." + name + if newname is not None: + matches.append((newname, self.objects[newname])) + return matches + + def resolve_xref( + self, + env: BuildEnvironment, + fromdocname: str, + builder: Builder, + type: str, + target: str, + node: pending_xref, + contnode: Element, + ) -> Element | None: + modname = node.get("delta:module") + clsname = node.get("delta:class") + searchmode = 1 if node.hasattr("refspecific") else 0 + matches = self.find_obj(env, modname, clsname, target, type, searchmode) + + if not matches and type == "attr": + # fallback to meth (for property; Sphinx-2.4.x) + # this ensures that `:attr:` role continues to refer to the old property entry + # that defined by ``method`` directive in old reST files. + matches = self.find_obj(env, modname, clsname, target, "meth", searchmode) + if not matches and type == "meth": + # fallback to attr (for property) + # this ensures that `:meth:` in the old reST files can refer to the property + # entry that defined by ``property`` directive. + # + # Note: _prop is a secret role only for internal look-up. 
+ matches = self.find_obj(env, modname, clsname, target, "_prop", searchmode) + + if not matches: + return None + elif len(matches) > 1: + canonicals = [m for m in matches if not m[1].aliased] + if len(canonicals) == 1: + matches = canonicals + else: + logger.warning( + __("more than one target found for cross-reference %r: %s"), + target, + ", ".join(match[0] for match in matches), + type="ref", + subtype="python", + location=node, + ) + name, obj = matches[0] + + if obj[2] == "module": + return self._make_module_refnode(builder, fromdocname, name, contnode) + else: + # determine the content of the reference by conditions + content = find_pending_xref_condition(node, "resolved") + if content: + children = content.children + else: + # if not found, use contnode + children = [contnode] + + return make_refnode(builder, fromdocname, obj[0], obj[1], children, name) + + def resolve_any_xref( + self, + env: BuildEnvironment, + fromdocname: str, + builder: Builder, + target: str, + node: pending_xref, + contnode: Element, + ) -> list[tuple[str, Element]]: + modname = node.get("delta:module") + clsname = node.get("delta:class") + results: list[tuple[str, Element]] = [] + + # always search in "refspecific" mode with the :any: role + matches = self.find_obj(env, modname, clsname, target, None, 1) + multiple_matches = len(matches) > 1 + + for name, obj in matches: + if multiple_matches and obj.aliased: + # Skip duplicated matches + continue + + if obj[2] == "module": + results.append( + ( + "delta:mod", + self._make_module_refnode(builder, fromdocname, name, contnode), + ) + ) + else: + # determine the content of the reference by conditions + content = find_pending_xref_condition(node, "resolved") + if content: + children = content.children + else: + # if not found, use contnode + children = [contnode] + + results.append( + ( + "delta:" + self.role_for_objtype(obj[2]), + make_refnode( + builder, fromdocname, obj[0], obj[1], children, name + ), + ) + ) + return results + + def 
_make_module_refnode( + self, builder: Builder, fromdocname: str, name: str, contnode: Node + ) -> Element: + # get additional info for modules + module = self.modules[name] + title = name + if module.synopsis: + title += ": " + module.synopsis + if module.deprecated: + title += _(" (deprecated)") + if module.platform: + title += " (" + module.platform + ")" + return make_refnode( + builder, fromdocname, module.docname, module.node_id, contnode, title + ) + + def get_objects(self) -> Iterator[tuple[str, str, str, str, str, int]]: + for modname, mod in self.modules.items(): + yield (modname, modname, "module", mod.docname, mod.node_id, 0) + for refname, obj in self.objects.items(): + if obj.objtype != "module": # modules are already handled + if obj.aliased: + # aliased names are not full-text searchable. + yield (refname, refname, obj.objtype, obj.docname, obj.node_id, -1) + else: + yield (refname, refname, obj.objtype, obj.docname, obj.node_id, 1) + + def get_full_qualified_name(self, node: Element) -> str | None: + modname = node.get("delta:module") + clsname = node.get("delta:class") + target = node.get("reftarget") + if target is None: + return None + else: + return ".".join(filter(None, [modname, clsname, target])) + + +def setup(app: Sphinx) -> dict[str, Any]: + app.setup_extension("sphinx.directives") + app.add_domain(DeltaDomain) + + return { + "version": "builtin", + "env_version": 3, + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/velox/docs/functions/delta/functions.rst b/velox/docs/functions/delta/functions.rst new file mode 100644 index 000000000000..e1b3de5673fb --- /dev/null +++ b/velox/docs/functions/delta/functions.rst @@ -0,0 +1,13 @@ +******************** +Delta Lake Functions +******************** + +Here is a list of all scalar Delta Lake functions available in Velox. +Function names link to function description. + +These functions are used in deletion vector read. 
+Refer to `Delta Lake documentation `_ and `Delta Lake deletion vector blog `_ for details. + +.. delta:function:: bitmap_array_contains(bitmap_array: varbinary, input: bigint) -> bool + + Not implemented. diff --git a/velox/docs/functions/iceberg/functions.rst b/velox/docs/functions/iceberg/functions.rst index 768f79b43a55..3f5fa47feda2 100644 --- a/velox/docs/functions/iceberg/functions.rst +++ b/velox/docs/functions/iceberg/functions.rst @@ -22,3 +22,24 @@ Refer to `Iceberg documentation same type as input + + Returns the truncated value of the input based on the specified width. + For numeric values, truncate to the nearest lower multiple of ``width``, the truncate function is: input - (((input % width) + width) % width). + The ``width`` used to truncate decimal values is applied to the unscaled value to avoid additional (and potentially conflicting) parameters. + For string values, it truncates a valid UTF-8 string with no more than ``width`` code points. + In contrast to strings, binary values do not have an assumed encoding and are truncated to ``width`` bytes. + + Argument ``width`` must be a positive integer. + Supported types for ``input`` are: SHORTINT, TINYINT, SMALLINT, INTEGER, BIGINT, DECIMAL, VARCHAR, VARBINARY. :: + + SELECT truncate(10, 11); -- 10 + SELECT truncate(10, -11); -- -20 + SELECT truncate(7, 22); -- 21 + SELECT truncate(0, 11); -- error: Reason: (0 vs. 0) Invalid truncate width\nExpression: width <= 0 + SELECT truncate(-3, 11); -- error: Reason: (-3 vs.
 0) Invalid truncate width\nExpression: width <= 0 + SELECT truncate(4, 'iceberg'); -- 'iceb' + SELECT truncate(1, '测试'); -- 测 + SELECT truncate(6, '测试'); -- 测试 + SELECT truncate(6, cast('测试' as binary)); -- 测试_ diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index ffa394f6e4ca..5ace7644cc15 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -34,6 +34,7 @@ Array Functions Returns the average of all non-null elements of the array. If there are no non-null elements, returns null. .. function:: array_cum_sum(array(T)) -> array(T) + Returns the array whose elements are the cumulative sum of the input array, i.e. result[i] = input[1] + input[2] + … + input[i]. If there are null elements in the array, the cumulative sum at and after the element is null. The following types are supported: int8_t, int16_t, int32_t, int64_t, int128_t, float, double, ShortDecimal, @@ -209,7 +210,7 @@ Array Functions SELECT array_sort(ARRAY [ARRAY [1, 2], ARRAY [1, null]]); -- failed: Ordering nulls is not supported .. function:: array_sort_desc(array(T), function(T,U)) -> array(T) - :noindex: + :noindex: Returns the array sorted by values computed using specified lambda in descending order. U must be an orderable type. Null elements will be placed at the end of @@ -217,7 +218,20 @@ Array Functions nested nulls. Throws if deciding the order of elements would require comparing nested null values. :: - SELECT array_sort_desc(ARRAY ['cat', 'leopard', 'mouse'], x -> length(x)); -- ['leopard', 'mouse', 'cat'] + SELECT array_sort_desc(ARRAY ['cat', 'leopard', 'mouse'], x -> length(x)); -- ['leopard', 'mouse', 'cat'] + +.. function:: array_subset(array(T), array(int)) -> array(T) + + Returns an array containing elements from the input array at the specified 1-based indices. + Indices must be positive integers. Invalid indices (out of bounds, zero, or negative) are ignored.
+ Null elements at valid indices are preserved in the output. Duplicate indices result in duplicate elements in the output. + The output maintains the order of the indices array. :: + + SELECT array_subset(ARRAY[1, 2, 3, 4, 5], ARRAY[1, 3, 5]); -- [1, 3, 5] + SELECT array_subset(ARRAY['a', 'b', 'c'], ARRAY[3, 1, 2]); -- ['c', 'a', 'b'] + SELECT array_subset(ARRAY[1, NULL, 3], ARRAY[2]); -- [NULL] + SELECT array_subset(ARRAY[1, 2, 3], ARRAY[1, 1, 2]); -- [1, 1, 2] + SELECT array_subset(ARRAY[1, 2, 3], ARRAY[5, 0, -1]); -- [] .. function:: array_sum(array(T)) -> bigint/double diff --git a/velox/docs/functions/presto/datetime.rst b/velox/docs/functions/presto/datetime.rst index 98e42dac552b..4f5802d98f01 100644 --- a/velox/docs/functions/presto/datetime.rst +++ b/velox/docs/functions/presto/datetime.rst @@ -144,7 +144,10 @@ Date and Time Functions .. function:: from_unixtime(unixtime) -> timestamp - Returns the UNIX timestamp ``unixtime`` as a timestamp. + Returns the UNIX timestamp ``unixtime`` as a timestamp. If the + :doc:`adjust_timestamp_to_session_timezone <../../configs>` property is set + to true, then the timestamp is adjusted to the time zone specified in + :doc:`session_timezone <../../configs>`. .. function:: from_unixtime(unixtime, string) -> timestamp with time zone :noindex: diff --git a/velox/docs/functions/presto/geospatial.rst b/velox/docs/functions/presto/geospatial.rst index 256a6cb825f4..9a612d58f1ec 100644 --- a/velox/docs/functions/presto/geospatial.rst +++ b/velox/docs/functions/presto/geospatial.rst @@ -73,6 +73,28 @@ Geometry Constructors Returns a geometry type polygon object from WKT representation. +.. function:: ST_LineFromText(wkt: varchar) -> linestring: Geometry + + Returns a geometry type linestring object from WKT representation. + An error is returned if the input WKT represents a valid non-LineString + geometry. Null input returns null output. + +.. 
function:: ST_LineString(points: array(Geometry)) -> linestring: Geometry + + Returns a LineString formed from an array of points. If there are fewer + than two non-empty points in the input array, an empty LineString will + be returned. Throws an exception if any element in the array is null or + empty or same as the previous one. The returned geometry may not be simple, + e.g. may self-intersect or may contain duplicate vertexes depending on the + input. + +.. function:: ST_MultiPoint(points: array(Geometry)) -> multipoint: Geometry + + Returns a MultiPoint geometry object formed from the specified points. + Return null if input array is empty. Throws an exception if any element + in the array is null or empty. The returned geometry may not be simple + and may contain duplicate points if input array has duplicates. + Spatial Predicates ------------------ @@ -114,7 +136,7 @@ function you are using. Returns ``true`` if the given geometries share space, are of the same dimension, but are not completely contained by each other. -.. function:: ST_Relat(geometry1: Geometry, geometry2: Geometry, relation: varchar) -> boolean +.. function:: ST_Relate(geometry1: Geometry, geometry2: Geometry, relation: varchar) -> boolean Returns true if first geometry is spatially related to second geometry as described by the relation. The relation is a string like ``'"1*T***T**'``: @@ -178,6 +200,13 @@ Spatial Operations Empty geometries will return an empty polygon. Negative or NaN distances will return an error. Positive infinity distances may lead to undefined results. +.. function:: geometry_union(geometries: array(Geometry)) -> union: Geometry + + Returns a geometry that represents the point set union of the input geometries. + Performance of this function, in conjunction with array_agg() to first + aggregate the input geometries, may be better than geometry_union_agg(), + at the expense of higher memory utilization. Null elements in the input + array are ignored. 
Empty array input returns null. Accessors --------- @@ -224,7 +253,7 @@ Accessors Returns an array of points in a geometry. Empty or null inputs return null. -.. function:: ST_NumPoints(geometry: Geometry) -> points: integer +.. function:: ST_NumPoints(geometry: Geometry) -> points: bigint Returns the number of points in a geometry. This is an extension to the SQL/MM ``ST_NumPoints`` function which only applies to @@ -513,5 +542,13 @@ for more details. given zoom level. Empty inputs return an empty array, and null inputs return null. +.. function:: geometry_to_dissolved_bing_tiles(geometry: Geometry, max_zoom_level: tinyint) -> tile: array(BingTile) + + Returns the minimum set of Bing tiles that fully covers a given geometry at a + given zoom level, recursively dissolving full sets of children into parents. + This results in a smaller array of tiles of different zoom levels. + For example, if the non-dissolved covering is [“00”, “01”, “02”, “03”, “10”], + the dissolved covering would be [“0”, “10”]. Zoom levels from 0 to 23 are supported. + .. _OpenGIS Specifications: https://www.ogc.org/standards/ogcapi-features/ .. _SQL/MM Part 3: Spatial: https://www.iso.org/standard/31369.html diff --git a/velox/docs/functions/presto/hyperloglog.rst b/velox/docs/functions/presto/hyperloglog.rst index ecd8e6d384ab..fd10739f5a53 100644 --- a/velox/docs/functions/presto/hyperloglog.rst +++ b/velox/docs/functions/presto/hyperloglog.rst @@ -70,3 +70,10 @@ Functions Returns the ``HyperLogLog`` of the aggregate union of the individual ``hll`` HyperLogLog structures. + +.. function:: merge_hll(array(HyperLogLog)) -> HyperLogLog + + Returns the ``HyperLogLog`` of the union of an array of ``HyperLogLog`` structures. 
+ + * Returns ``NULL`` if the input array is ``NULL``, empty, or contains only ``NULL`` elements + * Ignores ``NULL`` elements and merges only valid ``HyperLogLog`` structures when the array contains a mix of ``NULL`` and non-null elements diff --git a/velox/docs/functions/presto/map.rst b/velox/docs/functions/presto/map.rst index c81d7eba948b..35476af0f8bc 100644 --- a/velox/docs/functions/presto/map.rst +++ b/velox/docs/functions/presto/map.rst @@ -94,6 +94,16 @@ Map Functions SELECT map_remove_null_values(MAP(ARRAY[1, 2, 3], ARRAY[3, 4, NULL])); -- {1=3, 2=4} SELECT map_remove_null_values(NULL); -- NULL +.. function:: remap_keys(map(K,V), array(K), array(K)) -> map(K,V) + + Returns a map with keys remapped according to the oldKeys and newKeys arrays. + Unmapped keys remain unchanged. Values are preserved. Null keys are ignored. :: + + SELECT remap_keys(MAP(ARRAY[1, 2, 3], ARRAY[10, 20, 30]), ARRAY[1, 3], ARRAY[100, 300]); -- {100 -> 10, 2 -> 20, 300 -> 30} + SELECT remap_keys(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), ARRAY['a', 'c'], ARRAY['alpha', 'charlie']); -- {alpha -> 1, b -> 2, charlie -> 3} + SELECT remap_keys(MAP(ARRAY[1, 2, 3], ARRAY[10, null, 30]), ARRAY[1, 2], ARRAY[100, 200]); -- {100 -> 10, 200 -> null, 3 -> 30} + SELECT remap_keys(MAP(ARRAY[1, 2], ARRAY[10, 20]), ARRAY[], ARRAY[]); -- {1 -> 10, 2 -> 20} + .. function:: map_subset(map(K,V), array(k)) -> map(K,V) Constructs a map from those entries of ``map`` for which the key is in the array given diff --git a/velox/docs/functions/spark/array.rst b/velox/docs/functions/spark/array.rst index 42dbfe54e21d..42039e6c636d 100644 --- a/velox/docs/functions/spark/array.rst +++ b/velox/docs/functions/spark/array.rst @@ -171,13 +171,35 @@ Array Functions .. spark:function:: array_sort(array(E)) -> array(E) Returns an array which has the sorted order of the input array(E). The elements of array(E) must - be orderable. Null elements will be placed at the end of the returned array. :: + be orderable. 
NULL and NaN elements will be placed at the end of the returned array, with NaN elements appearing before NULL elements for floating-point types. :: SELECT array_sort(array(1, 2, 3)); -- [1, 2, 3] SELECT array_sort(array(3, 2, 1)); -- [1, 2, 3] - SELECT array_sort(array(2, 1, NULL); -- [1, 2, NULL] + SELECT array_sort(array(2, 1, NULL)); -- [1, 2, NULL] SELECT array_sort(array(NULL, 1, NULL)); -- [1, NULL, NULL] SELECT array_sort(array(NULL, 2, 1)); -- [1, 2, NULL] + SELECT array_sort(array(4.0, NULL, float('nan'), 3.0)); -- [3.0, 4.0, NaN, NULL] + SELECT array_sort(array(array(), array(1, 3, NULL), array(NULL, 6), NULL, array(2, 1))); -- [[], [NULL, 6], [1, 3, NULL], [2, 1], NULL] + +.. spark:function:: array_sort(array(E), function(E,U)) -> array(E) + :noindex: + + Returns the array sorted by values computed using specified lambda in ascending order. ``U`` must be an orderable type. + NULL and NaN elements returned by the lambda function will be placed at the end of the returned array, with NaN elements appearing before NULL elements. + This function is not supported in Spark and is only used inside Velox for rewriting :spark:func:`array_sort(array(E), function(E,E,U)) -> array(E)` as :spark:func:`array_sort(array(E), function(E,U)) -> array(E)`. :: + +.. spark:function:: array_sort(array(E), function(E,E,U)) -> array(E) + :noindex: + + Returns the array sorted by values computed using specified lambda in ascending + order. ``U`` must be an orderable type. + The function attempts to analyze the lambda function and rewrite it into a simpler call that + specifies the sort-by expression (like :spark:func:`array_sort(array(E), function(E,U)) -> array(E)`). For example, ``(left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))`` will be rewritten to ``x -> length(x)``. If rewrite is not possible, a user error will be thrown. 
+ If the rewritten function returns NULL, the corresponding element will be placed at the end of the returned array. Please note that due to this rewrite optimization, the NULL handling logic between Spark and Velox differs. In Spark, the position of the NULL element is determined by the comparison of NULL with other elements. :: + + SELECT array_sort(array('cat', 'leopard', 'mouse'), (left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))); -- ['cat', 'mouse', 'leopard'] + select array_sort(array("abcd123", "abcd", NULL, "abc"), (left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))); -- ["abc", "abcd", "abcd123", NULL] + select array_sort(array("abcd123", "abcd", NULL, "abc"), (left, right) -> if(length(left) > length(right), 1, if(length(left) = length(right), 0, -1))); -- ["abc", "abcd", "abcd123", NULL] different with Spark: ["abc", NULL, "abcd", "abcd123"] .. spark:function:: array_union(array(E) x, array(E) y) -> array(E) diff --git a/velox/docs/functions/spark/conversion.rst b/velox/docs/functions/spark/conversion.rst index 22a7da562c44..6ea90482f8cd 100644 --- a/velox/docs/functions/spark/conversion.rst +++ b/velox/docs/functions/spark/conversion.rst @@ -132,7 +132,7 @@ Valid examples SELECT cast(cast(2147483648.90 as DECIMAL(12, 2)) as bigint); -- 2147483648 From timestamp -^^^^^^^^^^^^^ +^^^^^^^^^^^^^^ Casting timestamp as integral types returns the number of seconds by converting timestamp as microseconds, dividing by the number of microseconds in a second, and then rounding down to the nearest second since the epoch (1970-01-01 00:00:00 UTC). diff --git a/velox/docs/functions/spark/datetime.rst b/velox/docs/functions/spark/datetime.rst index d87f5819885a..ed3ea63c9f58 100644 --- a/velox/docs/functions/spark/datetime.rst +++ b/velox/docs/functions/spark/datetime.rst @@ -244,6 +244,20 @@ These functions support TIMESTAMP and DATE input types. SELECT month('2009-07-30'); -- 7 +.. 
spark:function:: months_between(timestamp1, timestamp2, roundOff) -> double + + Returns number of months between times ``timestamp1`` and ``timestamp2``. + If ``timestamp1`` is later than ``timestamp2``, the result is positive. + If ``timestamp1`` and ``timestamp2`` are on the same day of month, or both are the + last day of month, time of day will be ignored. Otherwise, the difference is calculated + based on 31 days per month, and rounded to 8 digits unless ``roundOff`` is false. :: + + SELECT months_between('1997-02-28 10:30:00', '1996-10-30', true); -- 3.94959677 + SELECT months_between('1997-02-28 10:30:00', '1996-10-30', false); -- 3.9495967741935485 + SELECT months_between('1997-02-28 10:30:00', '1996-03-31 11:00:00', true); -- 11.0 + SELECT months_between('1997-02-28 10:30:00', '1996-03-28 11:00:00', true); -- 11.0 + SELECT months_between('1997-02-21 10:30:00', '1996-03-21 11:00:00', true); -- 11.0 + .. spark:function:: next_day(startDate, dayOfWeek) -> date Returns the first date which is later than ``startDate`` and named as ``dayOfWeek``. @@ -322,12 +336,14 @@ These functions support TIMESTAMP and DATE input types. converts the number of seconds to a timestamp. For floating-point types (FLOAT, DOUBLE), the function scales the input to microseconds, truncates towards zero, and saturates the result to the minimum and maximum values allowed - in Spark.:: + in Spark. Returns NULL when ``x`` is NaN or Infinity. :: SELECT timestamp_seconds(1230219000); -- '2008-12-25 15:30:00' SELECT timestamp_seconds(1230219000.123); -- '2008-12-25 15:30:00.123' SELECT timestamp_seconds(double(1.1234567)); -- '1970-01-01 00:00:01.123456' + SELECT timestamp_seconds(double('inf')); -- NULL SELECT timestamp_seconds(float(3.4028235E+38)); -- '+294247-01-10 04:00:54.775807' + SELECT timestamp_seconds(float('nan')); -- NULL .. 
spark:function:: to_unix_timestamp(date) -> bigint :noindex: diff --git a/velox/docs/functions/spark/math.rst b/velox/docs/functions/spark/math.rst index 08a11a2b60e0..3a637c828eea 100644 --- a/velox/docs/functions/spark/math.rst +++ b/velox/docs/functions/spark/math.rst @@ -2,12 +2,17 @@ Mathematical Functions ====================== -.. spark:function:: abs(x) -> [same as x] +.. spark:function:: abs(x) -> [same as x] (ANSI compliant) Returns the absolute value of ``x``. When ``x`` is negative minimum - value of integral type, returns the same value as ``x`` following - the behavior when Spark ANSI mode is disabled. - + value of integral type, returns the same value as ``x`` following + the behavior when Spark ANSI mode is disabled and throws an exception + when Spark ANSI mode is enabled. :: + + SELECT abs(-42); -- 42 + SELECT abs(3.14); -- 3.14 + SELECT abs(-128); -- 128 (with ANSI mode disabled) + SELECT abs(-128); -- Overflow exception (with ANSI mode enabled for TINYINT) .. spark:function:: acos(x) -> double Returns the inverse cosine (a.k.a. arc cosine) of ``x``. diff --git a/velox/docs/index.rst b/velox/docs/index.rst index 8fb3f14161b0..f1d65f2065e0 100644 --- a/velox/docs/index.rst +++ b/velox/docs/index.rst @@ -10,6 +10,7 @@ Velox Documentation functions spark_functions functions/iceberg/functions + functions/delta/functions configs monitoring bindings/python/index diff --git a/velox/docs/monitoring/metrics.rst b/velox/docs/monitoring/metrics.rst index 2764a1c46dd3..e34ee11dacca 100644 --- a/velox/docs/monitoring/metrics.rst +++ b/velox/docs/monitoring/metrics.rst @@ -292,9 +292,12 @@ Cache - Avg - Max possible age of AsyncDataCache and SsdCache entries since the raw file was opened to load the cache. - * - memory_cache_num_entries + * - memory_cache_num_large_entries - Avg - - Total number of cache entries. + - Total number of large cache entries. + * - memory_cache_num_tiny_entries + - Avg + - Total number of tiny cache entries. 
* - memory_cache_num_empty_entries - Avg - Total number of cache entries that do not cache anything. diff --git a/velox/docs/monthly-updates/may-2025.rst b/velox/docs/monthly-updates/may-2025.rst index fcacf0908535..d049377a9a52 100644 --- a/velox/docs/monthly-updates/may-2025.rst +++ b/velox/docs/monthly-updates/may-2025.rst @@ -1,6 +1,6 @@ -************** +*************** May 2025 Update -************** +*************** This update was generated with the assistance of AI. While we strive for accuracy, please note that AI-generated content may not always be error-free. We encourage you to verify any information diff --git a/velox/dwio/common/BufferedInput.h b/velox/dwio/common/BufferedInput.h index 1f877b3fa8d0..e2089c42a799 100644 --- a/velox/dwio/common/BufferedInput.h +++ b/velox/dwio/common/BufferedInput.h @@ -37,13 +37,15 @@ class BufferedInput { IoStatistics* stats = nullptr, filesystems::File::IoStats* fsStats = nullptr, uint64_t maxMergeDistance = kMaxMergeDistance, - std::optional wsVRLoad = std::nullopt) + std::optional wsVRLoad = std::nullopt, + folly::F14FastMap fileReadOps = {}) : BufferedInput( std::make_shared( std::move(readFile), metricsLog, stats, - fsStats), + fsStats, + std::move(fileReadOps)), pool, maxMergeDistance, wsVRLoad) {} diff --git a/velox/dwio/common/CMakeLists.txt b/velox/dwio/common/CMakeLists.txt index 3a4976bd625a..a2539e8458ec 100644 --- a/velox/dwio/common/CMakeLists.txt +++ b/velox/dwio/common/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(compression) add_subdirectory(encryption) add_subdirectory(exception) +add_subdirectory(wrap) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) @@ -45,6 +46,7 @@ velox_add_library( MetadataFilter.cpp Options.cpp OutputStream.cpp + ParallelUnitLoader.cpp ParallelFor.cpp Range.cpp Reader.cpp diff --git a/velox/dwio/common/CacheInputStream.cpp b/velox/dwio/common/CacheInputStream.cpp index dedda0c72f6a..99f09c86ca0c 100644 --- a/velox/dwio/common/CacheInputStream.cpp +++ 
b/velox/dwio/common/CacheInputStream.cpp @@ -394,7 +394,7 @@ velox::common::Region CacheInputStream::nextQuantizedLoadRegion( nextRegion.offset += (prevLoadedPosition / loadQuantum_) * loadQuantum_; // Set length to be the lesser of 'loadQuantum_' and distance to end of // 'region_' - nextRegion.length = std::min( + nextRegion.length = std::min( loadQuantum_, region_.length - (nextRegion.offset - region_.offset)); return nextRegion; } diff --git a/velox/dwio/common/CachedBufferedInput.h b/velox/dwio/common/CachedBufferedInput.h index ddb08061c9bd..4fb2475f5c85 100644 --- a/velox/dwio/common/CachedBufferedInput.h +++ b/velox/dwio/common/CachedBufferedInput.h @@ -27,8 +27,6 @@ #include "velox/dwio/common/CacheInputStream.h" #include "velox/dwio/common/InputStream.h" -DECLARE_int32(cache_load_quantum); - namespace facebook::velox::dwio::common { struct CacheRequest { @@ -64,13 +62,17 @@ class CachedBufferedInput : public BufferedInput { std::shared_ptr ioStats, std::shared_ptr fsStats, folly::Executor* executor, - const io::ReaderOptions& readerOptions) + const io::ReaderOptions& readerOptions, + folly::F14FastMap fileReadOps = {}) : BufferedInput( std::move(readFile), readerOptions.memoryPool(), metricsLog, ioStats.get(), - fsStats.get()), + fsStats.get(), + kMaxMergeDistance, + std::nullopt, + std::move(fileReadOps)), cache_(cache), fileNum_(std::move(fileNum)), tracker_(std::move(tracker)), diff --git a/velox/dwio/common/ColumnLoader.h b/velox/dwio/common/ColumnLoader.h index 50950b74e6e9..71d86c7e0aaf 100644 --- a/velox/dwio/common/ColumnLoader.h +++ b/velox/dwio/common/ColumnLoader.h @@ -17,6 +17,7 @@ #pragma once #include "velox/dwio/common/SelectiveStructColumnReader.h" +#include "velox/vector/LazyVector.h" namespace facebook::velox::dwio::common { @@ -30,11 +31,13 @@ class ColumnLoader : public VectorLoader { fieldReader_(fieldReader), version_(version) {} + virtual ~ColumnLoader() = default; + bool supportsHook() const override { return true; } - private: + 
protected: void loadInternal( RowSet rows, ValueHook* hook, diff --git a/velox/dwio/common/DirectBufferedInput.h b/velox/dwio/common/DirectBufferedInput.h index ea697c9e1755..74e29fe6f0cb 100644 --- a/velox/dwio/common/DirectBufferedInput.h +++ b/velox/dwio/common/DirectBufferedInput.h @@ -123,13 +123,17 @@ class DirectBufferedInput : public BufferedInput { std::shared_ptr ioStats, std::shared_ptr fsStats, folly::Executor* executor, - const io::ReaderOptions& readerOptions) + const io::ReaderOptions& readerOptions, + folly::F14FastMap fileReadOps = {}) : BufferedInput( std::move(readFile), readerOptions.memoryPool(), metricsLog, ioStats.get(), - fsStats.get()), + fsStats.get(), + kMaxMergeDistance, + std::nullopt, + std::move(fileReadOps)), fileNum_(std::move(fileNum)), tracker_(std::move(tracker)), groupId_(std::move(groupId)), diff --git a/velox/dwio/common/FlatMapHelper.cpp b/velox/dwio/common/FlatMapHelper.cpp index 5d48887af475..134f0cb1cfad 100644 --- a/velox/dwio/common/FlatMapHelper.cpp +++ b/velox/dwio/common/FlatMapHelper.cpp @@ -21,11 +21,20 @@ namespace facebook::velox::dwio::common::flatmap { namespace detail { -void reset(VectorPtr& vector, vector_size_t size, bool hasNulls) { +void reset( + VectorPtr& vector, + VectorEncoding::Simple desiredEncoding, + vector_size_t size, + bool hasNulls) { if (!vector) { return; } + if (vector->encoding() != desiredEncoding) { + vector.reset(); + return; + } + if (vector.use_count() > 1) { vector.reset(); return; @@ -162,7 +171,7 @@ void initializeVectorImpl( } } - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::ARRAY, size, hasNulls); VectorPtr origElementsVector; if (vector) { auto& arrayVector = dynamic_cast(*vector); @@ -226,7 +235,7 @@ void initializeMapVector( size = sizeOverride.value(); } - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::MAP, size, hasNulls); VectorPtr origKeysVector; VectorPtr origValuesVector; if (vector) { @@ 
-298,7 +307,7 @@ void initializeVectorImpl( } } - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::ROW, size, hasNulls); std::vector origChildren; if (vector) { auto& rowVector = dynamic_cast(*vector); diff --git a/velox/dwio/common/FlatMapHelper.h b/velox/dwio/common/FlatMapHelper.h index 05212cdd44ee..c9e73c072b05 100644 --- a/velox/dwio/common/FlatMapHelper.h +++ b/velox/dwio/common/FlatMapHelper.h @@ -24,7 +24,11 @@ namespace facebook::velox::dwio::common::flatmap { namespace detail { // Reset vector with the desired size/hasNulls properties -void reset(VectorPtr& vector, vector_size_t size, bool hasNulls); +void reset( + VectorPtr& vector, + VectorEncoding::Simple desiredEncoding, + vector_size_t size, + bool hasNulls); // Reset vector smart pointer if any of the buffers is not single referenced. template @@ -63,7 +67,7 @@ void initializeFlatVector( vector_size_t size, bool hasNulls, std::vector&& stringBuffers = {}) { - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::FLAT, size, hasNulls); if (vector) { auto& flatVector = dynamic_cast&>(*vector); detail::resetIfNotWritable(vector, flatVector.nulls(), flatVector.values()); diff --git a/velox/dwio/common/InputStream.cpp b/velox/dwio/common/InputStream.cpp index cc20a5fc55e5..c25a7b0be772 100644 --- a/velox/dwio/common/InputStream.cpp +++ b/velox/dwio/common/InputStream.cpp @@ -64,8 +64,10 @@ ReadFileInputStream::ReadFileInputStream( std::shared_ptr readFile, const MetricsLogPtr& metricsLog, IoStatistics* stats, - filesystems::File::IoStats* fsStats) + filesystems::File::IoStats* fsStats, + folly::F14FastMap fileReadOps) : InputStream(readFile->getName(), metricsLog, stats, fsStats), + fileStorageContext_(fsStats, std::move(fileReadOps)), readFile_(std::move(readFile)) {} void ReadFileInputStream::read( @@ -79,7 +81,7 @@ void ReadFileInputStream::read( std::string_view readData; { MicrosecondTimer timer(&readTimeUs); - readData = 
readFile_->pread(offset, length, buf, fsStats_); + readData = readFile_->pread(offset, length, buf, fileStorageContext_); } if (stats_) { stats_->incRawBytesRead(length); @@ -102,7 +104,7 @@ void ReadFileInputStream::read( LogType logType) { const int64_t bufferSize = totalBufferSize(buffers); logRead(offset, bufferSize, logType); - const auto size = readFile_->preadv(offset, buffers, fsStats_); + const auto size = readFile_->preadv(offset, buffers, fileStorageContext_); VELOX_CHECK_EQ( size, bufferSize, @@ -119,7 +121,7 @@ folly::SemiFuture ReadFileInputStream::readAsync( LogType logType) { const int64_t bufferSize = totalBufferSize(buffers); logRead(offset, bufferSize, logType); - return readFile_->preadvAsync(offset, buffers, fsStats_); + return readFile_->preadvAsync(offset, buffers, fileStorageContext_); } bool ReadFileInputStream::hasReadAsync() const { @@ -138,7 +140,7 @@ void ReadFileInputStream::vread( [&](size_t acc, const auto& r) { return acc + r.length; }); logRead(regions[0].offset, length, purpose); auto readStartMicros = getCurrentTimeMicro(); - readFile_->preadv(regions, iobufs, fsStats_); + readFile_->preadv(regions, iobufs, fileStorageContext_); if (stats_) { stats_->incRawBytesRead(length); stats_->incTotalScanTime((getCurrentTimeMicro() - readStartMicros) * 1000); diff --git a/velox/dwio/common/InputStream.h b/velox/dwio/common/InputStream.h index b0b6deb2c1f8..34dc948550c0 100644 --- a/velox/dwio/common/InputStream.h +++ b/velox/dwio/common/InputStream.h @@ -26,6 +26,7 @@ #include #include +#include #include "velox/common/file/File.h" #include "velox/common/file/Region.h" #include "velox/common/io/IoStatistics.h" @@ -143,7 +144,8 @@ class ReadFileInputStream final : public InputStream { std::shared_ptr, const MetricsLogPtr& metricsLog = MetricsLog::voidLog(), IoStatistics* stats = nullptr, - filesystems::File::IoStats* fsStats = nullptr); + filesystems::File::IoStats* fsStats = nullptr, + folly::F14FastMap fileReadOps = {}); 
~ReadFileInputStream() override = default; @@ -179,6 +181,7 @@ class ReadFileInputStream final : public InputStream { } private: + FileStorageContext fileStorageContext_; std::shared_ptr readFile_; }; diff --git a/velox/dwio/common/MetadataFilter.cpp b/velox/dwio/common/MetadataFilter.cpp index 374a2d861082..59181c557fea 100644 --- a/velox/dwio/common/MetadataFilter.cpp +++ b/velox/dwio/common/MetadataFilter.cpp @@ -18,6 +18,7 @@ #include #include "velox/dwio/common/ScanSpec.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" namespace facebook::velox::dwio::common { @@ -67,86 +68,118 @@ class MetadataFilter::LeafNode : public Node { std::unique_ptr filter_; }; -struct MetadataFilter::AndNode : Node { +struct MetadataFilter::ConditionNode : Node { static std::unique_ptr create( - std::unique_ptr lhs, - std::unique_ptr rhs) { - if (!lhs) { - return rhs; - } - if (!rhs) { - return lhs; + bool conjunction, + std::vector> args); + + static std::unique_ptr fromExpression( + const std::vector& inputs, + core::ExpressionEvaluator* evaluator, + bool conjunction, + bool negated) { + conjunction = negated ? 
!conjunction : conjunction; + std::vector> args; + args.reserve(inputs.size()); + for (const auto& input : inputs) { + auto node = Node::fromExpression(*input, evaluator, negated); + if (node) { + args.push_back(std::move(node)); + } else if (!conjunction) { + return nullptr; + } } - return std::make_unique(std::move(lhs), std::move(rhs)); + return create(conjunction, std::move(args)); } - AndNode(std::unique_ptr lhs, std::unique_ptr rhs) - : lhs_(std::move(lhs)), rhs_(std::move(rhs)) {} + explicit ConditionNode(std::vector> args) + : args_{std::move(args)} {} - void addToScanSpec(ScanSpec& scanSpec) const override { - lhs_->addToScanSpec(scanSpec); - rhs_->addToScanSpec(scanSpec); - } - - uint64_t* eval(LeafResults& leafResults, int size) const override { - auto* l = lhs_->eval(leafResults, size); - auto* r = rhs_->eval(leafResults, size); - if (!l) { - return r; - } - if (!r) { - return l; + void addToScanSpec(ScanSpec& scanSpec) const final { + for (const auto& arg : args_) { + arg->addToScanSpec(scanSpec); } - bits::orBits(l, r, 0, size); - return l; } - std::string toString() const override { - return "and(" + lhs_->toString() + "," + rhs_->toString() + ")"; + protected: + std::string ToStringImpl(std::string_view prefix) const { + std::string result{prefix}; + for (size_t i = 0; i < args_.size(); ++i) { + if (i != 0) { + result += ","; + } + result += args_[i]->toString(); + } + result += ")"; + return result; } - private: - std::unique_ptr lhs_; - std::unique_ptr rhs_; + std::vector> args_; }; -struct MetadataFilter::OrNode : Node { - static std::unique_ptr create( - std::unique_ptr lhs, - std::unique_ptr rhs) { - if (!lhs || !rhs) { - return nullptr; +struct MetadataFilter::AndNode final : ConditionNode { + using ConditionNode::ConditionNode; + + uint64_t* eval(LeafResults& leafResults, int size) const final { + uint64_t* result = nullptr; + for (const auto& arg : args_) { + auto* a = arg->eval(leafResults, size); + if (!a) { + continue; + } + if (!result) 
{ + result = a; + } else { + bits::orBits(result, a, 0, size); + } } - return std::make_unique(std::move(lhs), std::move(rhs)); + return result; } - OrNode(std::unique_ptr lhs, std::unique_ptr rhs) - : lhs_(std::move(lhs)), rhs_(std::move(rhs)) {} - - void addToScanSpec(ScanSpec& scanSpec) const override { - lhs_->addToScanSpec(scanSpec); - rhs_->addToScanSpec(scanSpec); + std::string toString() const final { + return ToStringImpl("and("); } +}; - uint64_t* eval(LeafResults& leafResults, int size) const override { - auto* l = lhs_->eval(leafResults, size); - auto* r = rhs_->eval(leafResults, size); - if (!l || !r) { - return nullptr; +struct MetadataFilter::OrNode final : ConditionNode { + using ConditionNode::ConditionNode; + + uint64_t* eval(LeafResults& leafResults, int size) const final { + uint64_t* result = nullptr; + for (const auto& arg : args_) { + auto* a = arg->eval(leafResults, size); + if (!a) { + return nullptr; + } + if (!result) { + result = a; + } else { + bits::andBits(result, a, 0, size); + } } - bits::andBits(l, r, 0, size); - return l; + return result; } - std::string toString() const override { - return "or(" + lhs_->toString() + "," + rhs_->toString() + ")"; + std::string toString() const final { + return ToStringImpl("or("); } - - private: - std::unique_ptr lhs_; - std::unique_ptr rhs_; }; +std::unique_ptr MetadataFilter::ConditionNode::create( + bool conjunction, + std::vector> args) { + if (args.empty()) { + return nullptr; + } + if (args.size() == 1) { + return std::move(args[0]); + } + if (conjunction) { + return std::make_unique(std::move(args)); + } + return std::make_unique(std::move(args)); +} + namespace { const core::CallTypedExpr* asCall(const core::ITypedExpr* expr) { @@ -163,17 +196,13 @@ std::unique_ptr MetadataFilter::Node::fromExpression( if (!call) { return nullptr; } - if (call->name() == "and") { - auto lhs = fromExpression(*call->inputs()[0], evaluator, negated); - auto rhs = fromExpression(*call->inputs()[1], evaluator, 
negated); - return negated ? OrNode::create(std::move(lhs), std::move(rhs)) - : AndNode::create(std::move(lhs), std::move(rhs)); - } - if (call->name() == "or") { - auto lhs = fromExpression(*call->inputs()[0], evaluator, negated); - auto rhs = fromExpression(*call->inputs()[1], evaluator, negated); - return negated ? AndNode::create(std::move(lhs), std::move(rhs)) - : OrNode::create(std::move(lhs), std::move(rhs)); + if (call->name() == expression::kAnd) { + return ConditionNode::fromExpression( + call->inputs(), evaluator, true, negated); + } + if (call->name() == expression::kOr) { + return ConditionNode::fromExpression( + call->inputs(), evaluator, false, negated); } if (call->name() == "not") { return fromExpression(*call->inputs()[0], evaluator, !negated); diff --git a/velox/dwio/common/MetadataFilter.h b/velox/dwio/common/MetadataFilter.h index 62b604b14407..d626bbdd9675 100644 --- a/velox/dwio/common/MetadataFilter.h +++ b/velox/dwio/common/MetadataFilter.h @@ -50,6 +50,7 @@ class MetadataFilter { private: struct Node; + struct ConditionNode; struct AndNode; struct OrNode; diff --git a/velox/dwio/common/OnDemandUnitLoader.cpp b/velox/dwio/common/OnDemandUnitLoader.cpp index d4ef4f0a5ef2..6a5616a31e53 100644 --- a/velox/dwio/common/OnDemandUnitLoader.cpp +++ b/velox/dwio/common/OnDemandUnitLoader.cpp @@ -15,12 +15,12 @@ */ #include "velox/dwio/common/OnDemandUnitLoader.h" +#include "velox/common/time/Timer.h" #include #include "velox/common/base/Exceptions.h" #include "velox/dwio/common/MeasureTime.h" -#include "velox/dwio/common/UnitLoaderTools.h" using facebook::velox::dwio::common::measureTimeIfCallback; @@ -42,6 +42,7 @@ class OnDemandUnitLoader : public UnitLoader { LoadUnit& getLoadedUnit(uint32_t unit) override { VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + processedUnits_.insert(unit); if (loadedUnit_.has_value()) { if (loadedUnit_.value() == unit) { return *loadUnits_[unit]; @@ -51,11 +52,14 @@ class OnDemandUnitLoader : public 
UnitLoader { loadedUnit_.reset(); } + uint64_t unitLoadNanos{0}; { + NanosecondTimer timer{&unitLoadNanos}; auto measure = measureTimeIfCallback(blockedOnIoCallback_); loadUnits_[unit]->load(); } loadedUnit_ = unit; + unitLoadNanos_ += unitLoadNanos; return *loadUnits_[unit]; } @@ -73,11 +77,28 @@ class OnDemandUnitLoader : public UnitLoader { rowOffsetInUnit, loadUnits_[unit]->getNumRows(), "Row out of range"); } + UnitLoaderStats stats() override { + UnitLoaderStats stats; + stats.addCounter("processedUnits", RuntimeCounter(processedUnits_.size())); + stats.addCounter( + "unitLoadNanos", + RuntimeCounter( + unitLoadNanos_ > std::numeric_limits::max() + ? std::numeric_limits::max() + : unitLoadNanos_, + RuntimeCounter::Unit::kNanos)); + return stats; + } + private: const std::vector> loadUnits_; const std::function blockedOnIoCallback_; std::optional loadedUnit_; + + // Stats + std::unordered_set processedUnits_; + uint64_t unitLoadNanos_{0}; }; } // namespace diff --git a/velox/dwio/common/Options.h b/velox/dwio/common/Options.h index ef70cb723794..b61f91fb8e2a 100644 --- a/velox/dwio/common/Options.h +++ b/velox/dwio/common/Options.h @@ -282,6 +282,22 @@ class RowReaderOptions { scanSpec_ = std::move(scanSpec); } + folly::Executor* ioExecutor() const { + return ioExecutor_; + } + + void setIOExecutor(folly::Executor* const ioExecutor) { + ioExecutor_ = ioExecutor; + } + + const size_t parallelUnitLoadCount() const { + return parallelUnitLoadCount_; + } + + void setParallelUnitLoadCount(size_t parallelUnitLoadCount) { + parallelUnitLoadCount_ = parallelUnitLoadCount; + } + const std::shared_ptr& metadataFilter() const { return metadataFilter_; } @@ -428,12 +444,21 @@ class RowReaderOptions { serdeParameters_ = std::move(serdeParameters); } + bool trackRowSize() const { + return trackRowSize_; + } + + void setTrackRowSize(bool value) { + trackRowSize_ = value; + } + private: uint64_t dataStart_; uint64_t dataLength_; bool preloadStripe_; bool projectSelectedType_; 
bool returnFlatVector_ = false; + size_t parallelUnitLoadCount_ = 0; ErrorTolerance errorTolerance_; std::shared_ptr selector_; RowTypePtr requestedType_; @@ -446,7 +471,8 @@ class RowReaderOptions { // Whether to generate FlatMapVectors when reading flat maps from the file. By // default, converts flat maps in the file to MapVectors. bool preserveFlatMapsInMemory_ = false; - + // Optional io executor to enable parallel unit loader. + folly::Executor* ioExecutor_; // Optional executors to enable internal reader parallelism. // 'decodingExecutor' allow parallelising the vector decoding process. // 'ioExecutor' enables parallelism when performing file system read @@ -485,6 +511,7 @@ class RowReaderOptions { TimestampPrecision timestampPrecision_ = TimestampPrecision::kMilliseconds; std::shared_ptr formatSpecificOptions_; + bool trackRowSize_{false}; }; /// Options for creating a Reader. diff --git a/velox/dwio/common/ParallelUnitLoader.cpp b/velox/dwio/common/ParallelUnitLoader.cpp new file mode 100644 index 000000000000..693f15548aef --- /dev/null +++ b/velox/dwio/common/ParallelUnitLoader.cpp @@ -0,0 +1,183 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/dwio/common/ParallelUnitLoader.h" +#include +#include "velox/common/base/Exceptions.h" +#include "velox/common/time/Timer.h" + +namespace facebook::velox::dwio::common { + +class ParallelUnitLoader : public UnitLoader { + public: + /// Enables concurrent loading of multiple units (stripes, row groups, etc.) + /// using asynchronous I/O to improve throughput and reduce read latency. + /// + /// **Loading Strategy:** + /// - Initialization: Preloads up to `maxConcurrentLoads` units concurrently + /// - Access pattern: On each getLoadedUnit() call, ensures the requested unit + /// is loaded and triggers loading of subsequent units within the window + /// - Memory management: Unloads all previous units to control memory usage + /// + /// **Performance Characteristics:** + /// - Best suited for sequential access patterns + /// - Memory usage: O(maxConcurrentLoads * average_unit_size) + /// - I/O parallelism: Up to `maxConcurrentLoads` concurrent load operations + /// + /// **Parameters:** + /// @param units All units to be loaded + /// @param ioExecutor Thread pool for asynchronous unit loading operations + /// @param maxConcurrentLoads Maximum units to load concurrently (sliding + /// window size) + /// + /// **Example with maxConcurrentLoads=3:** + /// ``` + /// Units: [0,1,2,3,4,5,6,7,8,9] + /// Init: Load [0,1,2] concurrently + /// Get(0): Wait for unit 0, trigger load of units [0,1,2], unload none + /// Get(1): Wait for unit 1, trigger load of units [1,2,3], unload [0] + /// Get(2): Wait for unit 2, trigger load of units [2,3,4], unload [0,1] + /// ``` + ParallelUnitLoader( + std::vector> units, + folly::Executor* ioExecutor, + uint16_t maxConcurrentLoads) + : loadUnits_(std::move(units)), + ioExecutor_(ioExecutor), + maxConcurrentLoads_(maxConcurrentLoads) { + VELOX_CHECK_NOT_NULL(ioExecutor, "ParallelUnitLoader ioExecutor is null"); + VELOX_CHECK_GT( + maxConcurrentLoads_, + 0, + "ParallelUnitLoader maxConcurrentLoads should be larger than 
0"); + futures_.resize(loadUnits_.size()); + unloadFutures_.resize(loadUnits_.size()); + unitsLoaded_.resize(loadUnits_.size()); + } + + /// Destructor ensures all pending load operations are properly cancelled + /// and waited for to prevent resource leaks and dangling references. + ~ParallelUnitLoader() override { + for (auto& future : futures_) { + future.cancel(); + future.wait(); + } + } + + LoadUnit& getLoadedUnit(uint32_t unit) override { + VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + + processedUnits_.insert(unit); + // Ensure sliding window of units [unit, unit + maxConcurrentLoads_) is + // loading + for (size_t i = unit; + i < loadUnits_.size() && i < unit + maxConcurrentLoads_; + ++i) { + if (!unitsLoaded_[i]) { + load(i); + } + } + + uint64_t unitLoadNanos{0}; + try { + NanosecondTimer timer{&unitLoadNanos}; + futures_[unit].wait(); + } catch (const std::exception& e) { + VELOX_FAIL("Failed to load unit {}: {}", unit, e.what()); + } + waitForUnitReadyNanos_ += unitLoadNanos; + + // Unload the previous units + unloadUntil(unit); + + return *loadUnits_[unit]; + } + + void onRead(uint32_t unit, uint64_t rowOffsetInUnit, uint64_t /* rowCount */) + override { + VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + VELOX_CHECK_LT( + rowOffsetInUnit, loadUnits_[unit]->getNumRows(), "Row out of range"); + } + + void onSeek(uint32_t unit, uint64_t rowOffsetInUnit) override { + VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + VELOX_CHECK_LE( + rowOffsetInUnit, loadUnits_[unit]->getNumRows(), "Row out of range"); + } + + UnitLoaderStats stats() override { + UnitLoaderStats stats; + stats.addCounter("processedUnits", RuntimeCounter(processedUnits_.size())); + stats.addCounter( + "waitForUnitReadyNanos", + RuntimeCounter( + waitForUnitReadyNanos_ > std::numeric_limits::max() + ? 
std::numeric_limits::max() + : waitForUnitReadyNanos_, + RuntimeCounter::Unit::kNanos)); + return stats; + } + + private: + /// Submits the unit's load() to the I/O thread pool + void load(uint32_t unitIndex) { + VELOX_CHECK_LT(unitIndex, loadUnits_.size(), "Unit index out of bounds"); + VELOX_CHECK_NOT_NULL(ioExecutor_, "ParallelUnitLoader ioExecutor is null"); + + futures_[unitIndex] = folly::via( + ioExecutor_, [this, unitIndex]() { loadUnits_[unitIndex]->load(); }); + unitsLoaded_[unitIndex] = true; + } + + /// Unloads all the units before 'unitIndex' + void unloadUntil(uint32_t unitIndex) { + for (size_t i = 0; i < unitIndex; ++i) { + if (unitsLoaded_[i]) { + loadUnits_[i]->unload(); + unitsLoaded_[i] = false; + } + } + } + + std::vector unitsLoaded_; + std::vector> loadUnits_; + std::vector> futures_; + std::vector> unloadFutures_; + folly::Executor* ioExecutor_; + size_t maxConcurrentLoads_; + + // Stats + std::unordered_set processedUnits_; + uint64_t waitForUnitReadyNanos_{0}; +}; + +std::unique_ptr ParallelUnitLoaderFactory::create( + std::vector> loadUnits, + uint64_t rowsToSkip) { + const auto totalRows = std::accumulate( + loadUnits.cbegin(), loadUnits.cend(), 0UL, [](uint64_t sum, auto& unit) { + return sum + unit->getNumRows(); + }); + VELOX_CHECK_LE( + rowsToSkip, + totalRows, + "Can only skip up to the past-the-end row of the file."); + return std::make_unique( + std::move(loadUnits), ioExecutor_, maxConcurrentLoads_); +} + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/ParallelUnitLoader.h b/velox/dwio/common/ParallelUnitLoader.h new file mode 100644 index 000000000000..0ba89028326c --- /dev/null +++ b/velox/dwio/common/ParallelUnitLoader.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include "velox/dwio/common/UnitLoader.h" + +namespace facebook::velox::dwio::common { +class ParallelUnitLoaderFactory : public UnitLoaderFactory { + public: + ParallelUnitLoaderFactory( + folly::Executor* ioExecutor, + size_t maxConcurrentLoads) + : ioExecutor_(ioExecutor), maxConcurrentLoads_(maxConcurrentLoads) {} + + std::unique_ptr create( + std::vector> loadUnits, + uint64_t rowsToSkip) override; + + private: + folly::Executor* ioExecutor_; + size_t maxConcurrentLoads_; +}; + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/PositionProvider.h b/velox/dwio/common/PositionProvider.h index 7be3bc7a1602..c99655a87e7c 100644 --- a/velox/dwio/common/PositionProvider.h +++ b/velox/dwio/common/PositionProvider.h @@ -23,7 +23,7 @@ namespace facebook::velox::dwio::common { class PositionProvider { public: explicit PositionProvider(const std::vector& positions) - : position_{positions.begin()}, end_{positions.end()} {} + : position_{positions.cbegin()}, end_{positions.cend()} {} uint64_t next(); diff --git a/velox/dwio/common/ScanSpec.h b/velox/dwio/common/ScanSpec.h index fbcac3d3a591..7310c0de8b7a 100644 --- a/velox/dwio/common/ScanSpec.h +++ b/velox/dwio/common/ScanSpec.h @@ -149,7 +149,9 @@ class ScanSpec { } void setSubscript(int64_t subscript) { - subscript_ = subscript; + if (subscript_ != subscript) { + subscript_ = subscript; + } } // True if the value is returned from scan. 
A runtime pushdown of a filter diff --git a/velox/dwio/common/SelectiveStructColumnReader.cpp b/velox/dwio/common/SelectiveStructColumnReader.cpp index 7c0e9b208f9b..e9776575e11e 100644 --- a/velox/dwio/common/SelectiveStructColumnReader.cpp +++ b/velox/dwio/common/SelectiveStructColumnReader.cpp @@ -531,6 +531,12 @@ bool SelectiveStructColumnReaderBase::isChildMissing( childSpec.channel() >= fileType_->size()); } +std::unique_ptr +SelectiveStructColumnReaderBase::makeColumnLoader(vector_size_t index) { + return std::make_unique( + this, children_[index], numReads_); +} + void SelectiveStructColumnReaderBase::getValues( const RowSet& rows, VectorPtr* result) { @@ -616,7 +622,7 @@ void SelectiveStructColumnReaderBase::getValues( // LazyVector result. setOutputRowsForLazy(rows); setLazyField( - std::make_unique(this, children_[index], numReads_), + makeColumnLoader(index), resultRow->type()->childAt(channel), rows.size(), memoryPool_, diff --git a/velox/dwio/common/SelectiveStructColumnReader.h b/velox/dwio/common/SelectiveStructColumnReader.h index e867caa80eeb..4cbc4189a781 100644 --- a/velox/dwio/common/SelectiveStructColumnReader.h +++ b/velox/dwio/common/SelectiveStructColumnReader.h @@ -20,6 +20,8 @@ namespace facebook::velox::dwio::common { +class ColumnLoader; + template class SelectiveFlatMapColumnReaderHelper; @@ -161,6 +163,12 @@ class SelectiveStructColumnReaderBase : public SelectiveColumnReader { const int64_t offset, const int32_t rowsPerRowGroup); + virtual std::unique_ptr makeColumnLoader( + vector_size_t index); + + // Sequence number of output batch. Checked against ColumnLoaders + // created by 'this' to verify they are still valid at load. + uint64_t numReads_ = 0; std::vector children_; private: @@ -189,10 +197,6 @@ class SelectiveStructColumnReaderBase : public SelectiveColumnReader { // Dense set of rows to read in next(). raw_vector rows_; - // Sequence number of output batch. 
Checked against ColumnLoaders - // created by 'this' to verify they are still valid at load. - uint64_t numReads_ = 0; - int64_t lazyVectorReadOffset_; int64_t currentRowNumber_ = -1; diff --git a/velox/dwio/common/Statistics.h b/velox/dwio/common/Statistics.h index 1c6965d6d71c..6790fb92b331 100644 --- a/velox/dwio/common/Statistics.h +++ b/velox/dwio/common/Statistics.h @@ -18,6 +18,7 @@ #include #include +#include "velox/dwio/common/UnitLoader.h" #include "velox/common/base/Exceptions.h" #include "velox/common/base/RuntimeMetrics.h" @@ -536,6 +537,9 @@ struct ColumnReaderStatistics { // Number of rows returned by string dictionary reader that is flattened // instead of keeping dictionary encoding. int64_t flattenStringDictionaryValues{0}; + + // Total time spent in loading pages, in nanoseconds. + uint64_t pageLoadTimeNs{0}; }; struct RuntimeStatistics { @@ -558,39 +562,48 @@ struct RuntimeStatistics { int64_t numStripes{0}; + UnitLoaderStats unitLoaderStats; ColumnReaderStatistics columnReaderStatistics; - std::unordered_map toMap() { - std::unordered_map result; + std::unordered_map toRuntimeMetricMap() { + std::unordered_map result; + for (const auto& [name, metric] : unitLoaderStats.stats()) { + result.emplace(name, RuntimeMetric(metric.sum, metric.unit)); + } if (skippedSplits > 0) { - result.emplace("skippedSplits", RuntimeCounter(skippedSplits)); + result.emplace("skippedSplits", RuntimeMetric(skippedSplits)); } if (processedSplits > 0) { - result.emplace("processedSplits", RuntimeCounter(processedSplits)); + result.emplace("processedSplits", RuntimeMetric(processedSplits)); } if (skippedSplitBytes > 0) { result.emplace( "skippedSplitBytes", - RuntimeCounter(skippedSplitBytes, RuntimeCounter::Unit::kBytes)); + RuntimeMetric(skippedSplitBytes, RuntimeCounter::Unit::kBytes)); } if (skippedStrides > 0) { - result.emplace("skippedStrides", RuntimeCounter(skippedStrides)); + result.emplace("skippedStrides", RuntimeMetric(skippedStrides)); } if 
(processedStrides > 0) { - result.emplace("processedStrides", RuntimeCounter(processedStrides)); + result.emplace("processedStrides", RuntimeMetric(processedStrides)); } if (footerBufferOverread > 0) { result.emplace( "footerBufferOverread", - RuntimeCounter(footerBufferOverread, RuntimeCounter::Unit::kBytes)); + RuntimeMetric(footerBufferOverread, RuntimeCounter::Unit::kBytes)); } if (numStripes > 0) { - result.emplace("numStripes", RuntimeCounter(numStripes)); + result.emplace("numStripes", RuntimeMetric(numStripes)); } if (columnReaderStatistics.flattenStringDictionaryValues > 0) { result.emplace( "flattenStringDictionaryValues", - RuntimeCounter(columnReaderStatistics.flattenStringDictionaryValues)); + RuntimeMetric(columnReaderStatistics.flattenStringDictionaryValues)); + } + if (columnReaderStatistics.pageLoadTimeNs > 0) { + result.emplace( + "pageLoadTimeNs", + RuntimeMetric(columnReaderStatistics.pageLoadTimeNs)); } return result; } diff --git a/velox/dwio/common/UnitLoader.h b/velox/dwio/common/UnitLoader.h index d3125dacc4be..d1fc54ab2407 100644 --- a/velox/dwio/common/UnitLoader.h +++ b/velox/dwio/common/UnitLoader.h @@ -16,9 +16,13 @@ #pragma once +#include +#include #include #include #include +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/RuntimeMetrics.h" namespace facebook::velox::dwio::common { @@ -39,6 +43,44 @@ class LoadUnit { virtual uint64_t getIoSize() = 0; }; +class UnitLoaderStats { + public: + UnitLoaderStats() = default; + + void addCounter(const std::string& name, RuntimeCounter counter) { + auto locked = stats_.wlock(); + auto it = locked->find(name); + if (it == locked->end()) { + auto [ptr, inserted] = locked->emplace(name, RuntimeMetric(counter.unit)); + VELOX_CHECK(inserted); + ptr->second.addValue(counter.value); + } else { + VELOX_CHECK_EQ(it->second.unit, counter.unit); + it->second.addValue(counter.value); + } + } + + void merge(const UnitLoaderStats& other) { + auto otherStats = other.stats(); + auto 
locked = stats_.wlock(); + for (const auto& [name, metric] : otherStats) { + auto it = locked->find(name); + if (it == locked->end()) { + locked->emplace(name, metric); + } else { + it->second.merge(metric); + } + } + } + + folly::F14FastMap stats() const { + return stats_.copy(); + } + + private: + folly::Synchronized> stats_; +}; + class UnitLoader { public: virtual ~UnitLoader() = default; @@ -56,6 +98,10 @@ class UnitLoader { /// Reader reports seek calling this method. The call must be done **before** /// getLoadedUnit for the new unit. virtual void onSeek(uint32_t unit, uint64_t rowOffsetInUnit) = 0; + + virtual UnitLoaderStats stats() { + return UnitLoaderStats(); + }; }; class UnitLoaderFactory { diff --git a/velox/dwio/common/compression/Compression.cpp b/velox/dwio/common/compression/Compression.cpp index 222385b3ce47..1d571ce8c96f 100644 --- a/velox/dwio/common/compression/Compression.cpp +++ b/velox/dwio/common/compression/Compression.cpp @@ -555,6 +555,8 @@ bool ZlibDecompressionStream::readOrSkip(const void** data, int32_t* size) { *size = static_cast(availSize); outputBufferPtr_ = inputBufferPtr_ + availSize; outputBufferLength_ = 0; + inputBufferPtr_ += availSize; + remainingLength_ -= availSize; } else { DWIO_ENSURE_EQ( state_, @@ -567,42 +569,49 @@ bool ZlibDecompressionStream::readOrSkip(const void** data, int32_t* size) { getDecompressedLength(inputBufferPtr_, availSize).first); reset(); - zstream_.next_in = - reinterpret_cast(const_cast(inputBufferPtr_)); - zstream_.avail_in = folly::to(availSize); - outputBufferPtr_ = outputBuffer_->data(); - zstream_.next_out = - reinterpret_cast(const_cast(outputBufferPtr_)); - zstream_.avail_out = folly::to(blockSize_); int32_t result; + *size = 0; do { - result = inflate( - &zstream_, availSize == remainingLength_ ? 
Z_FINISH : Z_SYNC_FLUSH); - switch (result) { - case Z_OK: - remainingLength_ -= availSize; - inputBufferPtr_ += availSize; - readBuffer(true); - availSize = std::min( - static_cast(inputBufferPtrEnd_ - inputBufferPtr_), - remainingLength_); - zstream_.next_in = - reinterpret_cast(const_cast(inputBufferPtr_)); - zstream_.avail_in = static_cast(availSize); - break; - case Z_STREAM_END: - break; - default: - DWIO_RAISE( - "Error in ZlibDecompressionStream::Next in ", - getName(), - ". error: ", - result, - " Info: ", - ZlibDecompressor::streamDebugInfo_); + if (inputBufferPtr_ == inputBufferPtrEnd_) { + readBuffer(true); } + zstream_.next_in = + reinterpret_cast(const_cast(inputBufferPtr_)); + zstream_.avail_in = + static_cast(inputBufferPtrEnd_ - inputBufferPtr_); + + do { + // size_ of outputBuffer_ is not updated in inflate, so *size is used + // here to ensure enough capacity for the output data. + outputBuffer_->extend(*size); + outputBufferPtr_ = outputBuffer_->data(); + zstream_.next_out = reinterpret_cast( + const_cast(outputBufferPtr_ + *size)); + zstream_.avail_out = folly::to(blockSize_); + result = inflate(&zstream_, Z_SYNC_FLUSH); + // Result handling adapted from https://zlib.net/zlib_how.html + switch (result) { + case Z_NEED_DICT: + result = Z_DATA_ERROR; + [[fallthrough]]; + case Z_DATA_ERROR: + [[fallthrough]]; + case Z_MEM_ERROR: + [[fallthrough]]; + case Z_STREAM_ERROR: + DWIO_RAISE("Failed to inflate input data. 
error: ", result); + default: + *size += static_cast( + blockSize_ - static_cast(zstream_.avail_out)); + const size_t inputConsumed = + reinterpret_cast(zstream_.next_in) - + inputBufferPtr_; + remainingLength_ -= inputConsumed; + inputBufferPtr_ += inputConsumed; + } + } while (zstream_.avail_out == 0); } while (result != Z_STREAM_END); - *size = static_cast(blockSize_ - zstream_.avail_out); + if (data) { *data = outputBufferPtr_; } @@ -610,8 +619,6 @@ bool ZlibDecompressionStream::readOrSkip(const void** data, int32_t* size) { outputBufferPtr_ += *size; } - inputBufferPtr_ += availSize; - remainingLength_ -= availSize; bytesReturned_ += *size; return true; } diff --git a/velox/dwio/common/tests/CMakeLists.txt b/velox/dwio/common/tests/CMakeLists.txt index b65afce13672..cb6f356a31a7 100644 --- a/velox/dwio/common/tests/CMakeLists.txt +++ b/velox/dwio/common/tests/CMakeLists.txt @@ -38,6 +38,7 @@ add_executable( DecoderUtilTest.cpp ExecutorBarrierTest.cpp OnDemandUnitLoaderTests.cpp + ParallelUnitLoaderTest.cpp LocalFileSinkTest.cpp MemorySinkTest.cpp LoggedExceptionTest.cpp diff --git a/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp b/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp index 178ad21f9b2d..245c7d6186de 100644 --- a/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp +++ b/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp @@ -17,9 +17,8 @@ #include #include -#include "velox/common/base/tests/GTestUtils.h" #include "velox/dwio/common/OnDemandUnitLoader.h" -#include "velox/dwio/common/UnitLoaderTools.h" +#include "velox/dwio/common/tests/UnitLoaderBaseTest.h" #include "velox/dwio/common/tests/utils/UnitLoaderTestTools.h" using namespace ::testing; @@ -31,6 +30,38 @@ using facebook::velox::dwio::common::test::getUnitsLoadedWithFalse; using facebook::velox::dwio::common::test::LoadUnitMock; using facebook::velox::dwio::common::test::ReaderMock; +class OnDemandUnitLoaderCommonTests + : public UnitLoaderBaseTest { + protected: + 
OnDemandUnitLoaderFactory createFactory() override { + return OnDemandUnitLoaderFactory(nullptr); + } +}; + +TEST_F(OnDemandUnitLoaderCommonTests, NoUnitButSkip) { + testNoUnitButSkip(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, InitialSkip) { + testInitialSkip(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, CanRequestUnitMultipleTimes) { + testCanRequestUnitMultipleTimes(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, UnitOutOfRange) { + testUnitOutOfRange(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, SeekOutOfRange) { + testSeekOutOfRange(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, SeekOutOfRangeReaderError) { + testSeekOutOfRangeReaderError(); +} + TEST(OnDemandUnitLoaderTests, LoadsCorrectlyWithReader) { size_t blockedOnIoCount = 0; OnDemandUnitLoaderFactory factory([&](auto) { ++blockedOnIoCount; }); @@ -127,96 +158,3 @@ TEST(OnDemandUnitLoaderTests, CanSeek) { EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, false, false})); EXPECT_EQ(blockedOnIoCount, 4); } - -TEST(OnDemandUnitLoaderTests, SeekOutOfRangeReaderError) { - size_t blockedOnIoCount = 0; - OnDemandUnitLoaderFactory factory([&](auto) { ++blockedOnIoCount; }); - ReaderMock readerMock{{10, 20, 30}, {0, 0, 0}, factory, 0}; - EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, false, false})); - EXPECT_EQ(blockedOnIoCount, 0); - readerMock.seek(59); - - readerMock.seek(60); - - VELOX_ASSERT_THROW( - readerMock.seek(61), - "Can't seek to possition 61 in file. 
Must be up to 60."); -} - -TEST(OnDemandUnitLoaderTests, SeekOutOfRange) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - - auto unitLoader = factory.create(std::move(units), 0); - - unitLoader->onSeek(0, 10); - - VELOX_ASSERT_THROW(unitLoader->onSeek(0, 11), "Row out of range"); -} - -TEST(OnDemandUnitLoaderTests, UnitOutOfRange) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - - auto unitLoader = factory.create(std::move(units), 0); - unitLoader->getLoadedUnit(0); - - VELOX_ASSERT_THROW(unitLoader->getLoadedUnit(1), "Unit out of range"); -} - -TEST(OnDemandUnitLoaderTests, CanRequestUnitMultipleTimes) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - - auto unitLoader = factory.create(std::move(units), 0); - unitLoader->getLoadedUnit(0); - unitLoader->getLoadedUnit(0); - unitLoader->getLoadedUnit(0); -} - -TEST(OnDemandUnitLoaderTests, InitialSkip) { - auto getFactoryWithSkip = [](uint64_t skipToRow) { - auto factory = std::make_unique(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - units.push_back(std::make_unique(20, 0, unitsLoaded, 1)); - units.push_back(std::make_unique(30, 0, unitsLoaded, 2)); - factory->create(std::move(units), skipToRow); - }; - - EXPECT_NO_THROW(getFactoryWithSkip(0)); - EXPECT_NO_THROW(getFactoryWithSkip(1)); - EXPECT_NO_THROW(getFactoryWithSkip(9)); - EXPECT_NO_THROW(getFactoryWithSkip(10)); - EXPECT_NO_THROW(getFactoryWithSkip(11)); - EXPECT_NO_THROW(getFactoryWithSkip(29)); - EXPECT_NO_THROW(getFactoryWithSkip(30)); - 
EXPECT_NO_THROW(getFactoryWithSkip(31)); - EXPECT_NO_THROW(getFactoryWithSkip(59)); - EXPECT_NO_THROW(getFactoryWithSkip(60)); - VELOX_ASSERT_THROW( - getFactoryWithSkip(61), - "Can only skip up to the past-the-end row of the file."); - VELOX_ASSERT_THROW( - getFactoryWithSkip(100), - "Can only skip up to the past-the-end row of the file."); -} - -TEST(OnDemandUnitLoaderTests, NoUnitButSkip) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector> units; - - EXPECT_NO_THROW(factory.create(std::move(units), 0)); - - std::vector> units2; - VELOX_ASSERT_THROW( - factory.create(std::move(units2), 1), - "Can only skip up to the past-the-end row of the file."); -} diff --git a/velox/dwio/common/tests/ParallelUnitLoaderTest.cpp b/velox/dwio/common/tests/ParallelUnitLoaderTest.cpp new file mode 100644 index 000000000000..690acd9fc119 --- /dev/null +++ b/velox/dwio/common/tests/ParallelUnitLoaderTest.cpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/dwio/common/ParallelUnitLoader.h" +#include "velox/dwio/common/OnDemandUnitLoader.h" +#include "velox/dwio/common/tests/UnitLoaderBaseTest.h" +#include "velox/dwio/common/tests/utils/UnitLoaderTestTools.h" + +#include +#include +#include + +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::dwio::common::test; + +class ParallelUnitLoaderTest + : public UnitLoaderBaseTest { + protected: + ParallelUnitLoaderFactory createFactory() override { + return ParallelUnitLoaderFactory(ioExecutor_.get(), 2); + } + + std::unique_ptr ioExecutor_ = + std::make_unique(10); +}; + +TEST_F(ParallelUnitLoaderTest, NoUnitButSkip) { + testNoUnitButSkip(); +} + +TEST_F(ParallelUnitLoaderTest, InitialSkip) { + testInitialSkip(); +} + +TEST_F(ParallelUnitLoaderTest, CanRequestUnitMultipleTimes) { + testCanRequestUnitMultipleTimes(); +} + +TEST_F(ParallelUnitLoaderTest, UnitOutOfRange) { + testUnitOutOfRange(); +} + +TEST_F(ParallelUnitLoaderTest, SeekOutOfRange) { + testSeekOutOfRange(); +} + +TEST_F(ParallelUnitLoaderTest, SeekOutOfRangeReaderError) { + testSeekOutOfRangeReaderError(); +} + +TEST_F(ParallelUnitLoaderTest, LoadsCorrectlyWithReader) { + auto factory = createFactory(); + ReaderMock readerMock{{10, 20, 30}, {0, 0, 0}, factory, 0}; + + EXPECT_TRUE(readerMock.read(3)); // Unit: 0, rows: 0-2, load(0) + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, true, false})); + + EXPECT_TRUE(readerMock.read(3)); // Unit: 0, rows: 3-5 + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, true, false})); + + EXPECT_TRUE(readerMock.read(4)); // Unit: 0, rows: 6-9 + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, true, false})); + + EXPECT_TRUE(readerMock.read(14)); // Unit: 1, rows: 0-13, unload(0), load(1) + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, true, true})); + + // will only 
read 5 rows, no more rows in unit 1 + EXPECT_TRUE(readerMock.read(10)); // Unit: 1, rows: 14-19 + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, true, true})); + + EXPECT_TRUE(readerMock.read(30)); // Unit: 2, rows: 0-29, unload(1), load(2) + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, false, true})); + + EXPECT_FALSE(readerMock.read(30)); // No more data + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, false, true})); +} + +// Performance comparison test +TEST_F(ParallelUnitLoaderTest, PerformanceComparison) { + std::vector rowsPerUnit = {100, 100, 100, 100, 100, 100, 100, 100}; + std::vector ioSizes = { + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; + + // Measure ParallelUnitLoader performance + auto parallelStart = std::chrono::high_resolution_clock::now(); + { + auto factory = createFactory(); + ReaderMock reader(rowsPerUnit, ioSizes, factory, 0); + + for (size_t i = 0; i < rowsPerUnit.size(); ++i) { + uint64_t totalRowsRead = 0; + while (totalRowsRead < rowsPerUnit[i]) { + reader.read(25); + int nextRead = rowsPerUnit[i] - totalRowsRead; + totalRowsRead += std::min(25, nextRead); + } + } + } + auto parallelEnd = std::chrono::high_resolution_clock::now(); + + // Measure OnDemandUnitLoader performance + auto onDemandStart = std::chrono::high_resolution_clock::now(); + { + auto factory = std::make_shared(nullptr); + ReaderMock reader(rowsPerUnit, ioSizes, *factory, 0); + + for (size_t i = 0; i < rowsPerUnit.size(); ++i) { + uint64_t totalRowsRead = 0; + while (totalRowsRead < rowsPerUnit[i]) { + reader.read(25); + int nextRead = rowsPerUnit[i] - totalRowsRead; + totalRowsRead += std::min(25, nextRead); + } + } + } + auto onDemandEnd = std::chrono::high_resolution_clock::now(); + + auto parallelDuration = std::chrono::duration_cast( + parallelEnd - parallelStart); + auto onDemandDuration = std::chrono::duration_cast( + onDemandEnd - onDemandStart); + + // 
ParallelUnitLoader should be faster + EXPECT_GT(onDemandDuration.count(), parallelDuration.count()); +} diff --git a/velox/dwio/common/tests/TestBufferedInput.cpp b/velox/dwio/common/tests/TestBufferedInput.cpp index 6fa5be8da000..6c84719bae4a 100644 --- a/velox/dwio/common/tests/TestBufferedInput.cpp +++ b/velox/dwio/common/tests/TestBufferedInput.cpp @@ -35,7 +35,7 @@ class ReadFileMock : public ::facebook::velox::ReadFile { (uint64_t offset, uint64_t length, void* buf, - facebook::velox::filesystems::File::IoStats* stats), + (const facebook::velox::FileStorageContext&)fileStorageContext), (const, override)); MOCK_METHOD(bool, shouldCoalesce, (), (const, override)); @@ -48,7 +48,7 @@ class ReadFileMock : public ::facebook::velox::ReadFile { preadv, (folly::Range regions, folly::Range iobufs, - facebook::velox::filesystems::File::IoStats* stats), + (const facebook::velox::FileStorageContext&)fileStorageContext), (const, override)); }; @@ -60,14 +60,14 @@ void expectPreads( EXPECT_CALL(file, size()).WillRepeatedly(Return(content.size())); for (auto& read : reads) { ASSERT_GE(content.size(), read.offset + read.length); - EXPECT_CALL(file, pread(read.offset, read.length, _, nullptr)) + EXPECT_CALL(file, pread(read.offset, read.length, _, _)) .Times(1) .WillOnce( [content]( uint64_t offset, uint64_t length, void* buf, - facebook::velox::filesystems::File::IoStats* stats) + const facebook::velox::FileStorageContext& fileStorageContext) -> std::string_view { memcpy(buf, content.data() + offset, length); return {content.data() + offset, length}; @@ -81,13 +81,14 @@ void expectPreadvs( std::vector reads) { EXPECT_CALL(file, getName()).WillRepeatedly(Return("mock_name")); EXPECT_CALL(file, size()).WillRepeatedly(Return(content.size())); - EXPECT_CALL(file, preadv(_, _, nullptr)) + EXPECT_CALL(file, preadv(_, _, _)) .Times(1) .WillOnce( [content, reads]( folly::Range regions, folly::Range iobufs, - facebook::velox::filesystems::File::IoStats* stats) -> uint64_t { + const 
facebook::velox::FileStorageContext& fileStorageContext) + -> uint64_t { EXPECT_EQ(regions.size(), reads.size()); uint64_t length = 0; for (size_t i = 0; i < reads.size(); ++i) { diff --git a/velox/dwio/common/tests/UnitLoaderBaseTest.h b/velox/dwio/common/tests/UnitLoaderBaseTest.h new file mode 100644 index 000000000000..9faf9bc91b3e --- /dev/null +++ b/velox/dwio/common/tests/UnitLoaderBaseTest.h @@ -0,0 +1,140 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/common/UnitLoaderTools.h" +#include "velox/dwio/common/tests/utils/UnitLoaderTestTools.h" + +using facebook::velox::dwio::common::LoadUnit; +using facebook::velox::dwio::common::test::getUnitsLoadedWithFalse; +using facebook::velox::dwio::common::test::LoadUnitMock; +using facebook::velox::dwio::common::test::ReaderMock; + +/// Base test class template that provides common test functionality for +/// different UnitLoader implementations. This template class can be inherited +/// by specific test classes to get access to common test methods. Each derived +/// class should provide a createFactory() method that returns the appropriate +/// factory instance. +template +class UnitLoaderBaseTest : public ::testing::Test { + protected: + /// Factory method to create the appropriate UnitLoaderFactory instance. 
+ /// This method should be implemented by derived classes. + virtual UnitLoaderFactoryType createFactory() = 0; + + /// Test that UnitLoader factory handles the case where no units exist but + /// skip is requested + void testNoUnitButSkip() { + UnitLoaderFactoryType factory = createFactory(); + std::vector> units; + + EXPECT_NO_THROW(factory.create(std::move(units), 0)); + + std::vector> units2; + VELOX_ASSERT_THROW( + factory.create(std::move(units2), 1), + "Can only skip up to the past-the-end row of the file."); + } + + /// Test that UnitLoader factory handles initial skip correctly for various + /// skip values + void testInitialSkip() { + auto getFactoryWithSkip = [this](uint64_t skipToRow) { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(3)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + units.push_back(std::make_unique(20, 0, unitsLoaded, 1)); + units.push_back(std::make_unique(30, 0, unitsLoaded, 2)); + factory.create(std::move(units), skipToRow); + }; + + EXPECT_NO_THROW(getFactoryWithSkip(0)); + EXPECT_NO_THROW(getFactoryWithSkip(1)); + EXPECT_NO_THROW(getFactoryWithSkip(9)); + EXPECT_NO_THROW(getFactoryWithSkip(10)); + EXPECT_NO_THROW(getFactoryWithSkip(11)); + EXPECT_NO_THROW(getFactoryWithSkip(29)); + EXPECT_NO_THROW(getFactoryWithSkip(30)); + EXPECT_NO_THROW(getFactoryWithSkip(31)); + EXPECT_NO_THROW(getFactoryWithSkip(59)); + EXPECT_NO_THROW(getFactoryWithSkip(60)); + VELOX_ASSERT_THROW( + getFactoryWithSkip(61), + "Can only skip up to the past-the-end row of the file."); + VELOX_ASSERT_THROW( + getFactoryWithSkip(100), + "Can only skip up to the past-the-end row of the file."); + } + + /// Test that the same unit can be requested multiple times without issues + void testCanRequestUnitMultipleTimes() { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + + 
auto unitLoader = factory.create(std::move(units), 0); + unitLoader->getLoadedUnit(0); + unitLoader->getLoadedUnit(0); + unitLoader->getLoadedUnit(0); + } + + /// Test that requesting a unit index out of range throws an exception + void testUnitOutOfRange() { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + + auto unitLoader = factory.create(std::move(units), 0); + unitLoader->getLoadedUnit(0); + + VELOX_ASSERT_THROW(unitLoader->getLoadedUnit(1), "Unit out of range"); + } + + /// Test that seeking out of range throws an exception + void testSeekOutOfRange() { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + + auto unitLoader = factory.create(std::move(units), 0); + + unitLoader->onSeek(0, 10); + + VELOX_ASSERT_THROW(unitLoader->onSeek(0, 11), "Row out of range"); + } + + /// Test that seeking out of range in ReaderMock throws appropriate exception + void testSeekOutOfRangeReaderError() { + auto factory = createFactory(); + ReaderMock readerMock{{10, 20, 30}, {0, 0, 0}, factory, 0}; + + readerMock.seek(59); + readerMock.seek(60); + + VELOX_ASSERT_THROW( + readerMock.seek(61), + "Can't seek to possition 61 in file. 
Must be up to 60."); + } +}; diff --git a/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp b/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp index f7656fbd08ca..09bbfde3e630 100644 --- a/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp +++ b/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp @@ -18,6 +18,7 @@ #include "velox/dwio/common/tests/utils/DataSetBuilder.h" #include "velox/expression/Expr.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/parse/Expressions.h" @@ -491,15 +492,31 @@ void E2EFilterTestBase::testMetadataFilterImpl( core::ExpressionEvaluator* evaluator, const std::string& remainingFilter, std::function validationFilter) { - SCOPED_TRACE(fmt::format("remainingFilter={}", remainingFilter)); + SCOPED_TRACE(fmt::format("remainingFilter='{}'", remainingFilter)); + auto untypedExpr = parse::parseExpr(remainingFilter, {}); + auto typedExpr = core::Expressions::inferTypes( + untypedExpr, batches[0]->type(), leafPool_.get()); + testMetadataFilterImpl( + batches, + std::move(filterField), + std::move(filter), + evaluator, + std::move(typedExpr), + std::move(validationFilter)); +} + +void E2EFilterTestBase::testMetadataFilterImpl( + const std::vector& batches, + common::Subfield filterField, + std::unique_ptr filter, + core::ExpressionEvaluator* evaluator, + core::TypedExprPtr typedExpr, + std::function validationFilter) { auto spec = std::make_shared(""); if (filter) { spec->getOrCreateChild(std::move(filterField)) ->setFilter(std::move(filter)); } - auto untypedExpr = parse::parseExpr(remainingFilter, {}); - auto typedExpr = core::Expressions::inferTypes( - untypedExpr, batches[0]->type(), leafPool_.get()); auto metadataFilter = std::make_shared(*spec, *typedExpr, evaluator); auto specA = spec->getOrCreateChild(common::Subfield("a")); @@ -621,6 +638,56 @@ void E2EFilterTestBase::testMetadataFilter() { 
[](int64_t a, int64_t) { return !!(a == 2 || a == 3 || a == 5 || a == 7); }); + { + SCOPED_TRACE("remainingFilter='a == 1 or a == 3 or a == 8'"); + auto typedExpr1 = core::Expressions::inferTypes( + parse::parseExpr("a == 1", {}), batches[0]->type(), leafPool_.get()); + auto typedExpr2 = core::Expressions::inferTypes( + parse::parseExpr("a == 3", {}), batches[0]->type(), leafPool_.get()); + auto typedExpr3 = core::Expressions::inferTypes( + parse::parseExpr("a == 8", {}), batches[0]->type(), leafPool_.get()); + + auto typedExpr = std::make_shared( + velox::BOOLEAN(), + std::vector{ + std::move(typedExpr1), + std::move(typedExpr2), + std::move(typedExpr3), + }, + expression::kOr); + testMetadataFilterImpl( + batches, + common::Subfield("a"), + nullptr, + &evaluator, + std::move(typedExpr), + [](int64_t a, int64_t) { return a == 1 || a == 3 || a == 8; }); + } + { + SCOPED_TRACE("remainingFilter='a >= 1 and a <= 100 and a == 8'"); + auto typedExpr1 = core::Expressions::inferTypes( + parse::parseExpr("a >= 1", {}), batches[0]->type(), leafPool_.get()); + auto typedExpr2 = core::Expressions::inferTypes( + parse::parseExpr("a <= 100", {}), batches[0]->type(), leafPool_.get()); + auto typedExpr3 = core::Expressions::inferTypes( + parse::parseExpr("b.c != 8", {}), batches[0]->type(), leafPool_.get()); + + auto typedExpr = std::make_shared( + velox::BOOLEAN(), + std::vector{ + std::move(typedExpr1), + std::move(typedExpr2), + std::move(typedExpr3), + }, + expression::kAnd); + testMetadataFilterImpl( + batches, + common::Subfield("a"), + nullptr, + &evaluator, + std::move(typedExpr), + [](int64_t a, int64_t c) { return a >= 1 && a <= 100 && c != 8; }); + } { SCOPED_TRACE("Values not unique in row group"); diff --git a/velox/dwio/common/tests/utils/E2EFilterTestBase.h b/velox/dwio/common/tests/utils/E2EFilterTestBase.h index 16d30e36a7b7..19ebfb6de659 100644 --- a/velox/dwio/common/tests/utils/E2EFilterTestBase.h +++ b/velox/dwio/common/tests/utils/E2EFilterTestBase.h @@ 
-336,6 +336,14 @@ class E2EFilterTestBase : public testing::Test { const std::string& remainingFilter, std::function validationFilter); + void testMetadataFilterImpl( + const std::vector& batches, + common::Subfield filterField, + std::unique_ptr filter, + core::ExpressionEvaluator* evaluator, + core::TypedExprPtr typedExpr, + std::function validationFilter); + protected: void testMetadataFilter(); diff --git a/velox/dwio/common/tests/utils/UnitLoaderTestTools.h b/velox/dwio/common/tests/utils/UnitLoaderTestTools.h index 9eae97f575c2..0d4018e4b4d6 100644 --- a/velox/dwio/common/tests/utils/UnitLoaderTestTools.h +++ b/velox/dwio/common/tests/utils/UnitLoaderTestTools.h @@ -32,16 +32,20 @@ class LoadUnitMock : public LoadUnit { uint64_t rowCount, uint64_t ioSize, std::vector& unitsLoaded, - size_t unitId) + size_t unitId, + std::chrono::milliseconds loadDelay = std::chrono::milliseconds(100)) : rowCount_{rowCount}, ioSize_{ioSize}, unitsLoaded_{unitsLoaded}, - unitId_{unitId} {} + unitId_{unitId}, + loadDelay_(loadDelay) {} ~LoadUnitMock() override = default; void load() override { VELOX_CHECK(!isLoaded()); + // Simulate loading time + std::this_thread::sleep_for(loadDelay_); unitsLoaded_[unitId_] = true; } @@ -67,6 +71,7 @@ class LoadUnitMock : public LoadUnit { uint64_t ioSize_; std::vector& unitsLoaded_; size_t unitId_; + std::chrono::milliseconds loadDelay_; }; class ReaderMock { diff --git a/velox/dwio/common/wrap/CMakeLists.txt b/velox/dwio/common/wrap/CMakeLists.txt new file mode 100644 index 000000000000..a598690b32e5 --- /dev/null +++ b/velox/dwio/common/wrap/CMakeLists.txt @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +velox_install_library_headers() diff --git a/velox/dwio/dwrf/CMakeLists.txt b/velox/dwio/dwrf/CMakeLists.txt index 613f69ad2f06..65188eb58d64 100644 --- a/velox/dwio/dwrf/CMakeLists.txt +++ b/velox/dwio/dwrf/CMakeLists.txt @@ -22,3 +22,5 @@ elseif(${VELOX_BUILD_TEST_UTILS}) endif() add_subdirectory(utils) add_subdirectory(writer) + +velox_install_library_headers() diff --git a/velox/dwio/dwrf/common/CMakeLists.txt b/velox/dwio/dwrf/common/CMakeLists.txt index 8a3dc1e393ef..3f8f38028eaa 100644 --- a/velox/dwio/dwrf/common/CMakeLists.txt +++ b/velox/dwio/dwrf/common/CMakeLists.txt @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +add_subdirectory(wrap) + velox_add_library( velox_dwio_dwrf_common ByteRLE.cpp diff --git a/velox/dwio/dwrf/common/Config.cpp b/velox/dwio/dwrf/common/Config.cpp index 8ce1ad5f2155..96fc64318b52 100644 --- a/velox/dwio/dwrf/common/Config.cpp +++ b/velox/dwio/dwrf/common/Config.cpp @@ -182,7 +182,7 @@ Config::Entry>> Config::Entry Config::MAP_FLAT_MAX_KEYS( "orc.map.flat.max.keys", - 20000); + 30000); Config::Entry Config::MAX_DICTIONARY_SIZE( "hive.exec.orc.max.dictionary.size", diff --git a/velox/dwio/dwrf/common/wrap/CMakeLists.txt b/velox/dwio/dwrf/common/wrap/CMakeLists.txt new file mode 100644 index 000000000000..a598690b32e5 --- /dev/null +++ b/velox/dwio/dwrf/common/wrap/CMakeLists.txt @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +velox_install_library_headers() diff --git a/velox/dwio/dwrf/reader/DwrfReader.cpp b/velox/dwio/dwrf/reader/DwrfReader.cpp index 4f0ca050d58d..4e41ce0fe88c 100644 --- a/velox/dwio/dwrf/reader/DwrfReader.cpp +++ b/velox/dwio/dwrf/reader/DwrfReader.cpp @@ -19,6 +19,7 @@ #include #include "velox/dwio/common/OnDemandUnitLoader.h" +#include "velox/dwio/common/ParallelUnitLoader.h" #include "velox/dwio/common/TypeUtils.h" #include "velox/dwio/common/exception/Exception.h" #include "velox/dwio/dwrf/reader/ColumnReader.h" @@ -341,9 +342,16 @@ std::unique_ptr DwrfRowReader::getUnitLoader() { std::shared_ptr unitLoaderFactory = options_.unitLoaderFactory(); if (!unitLoaderFactory) { - unitLoaderFactory = - std::make_shared( - options_.blockedOnIoCallback()); + if (loadUnits.size() > 1 && options_.parallelUnitLoadCount() > 1 && + options_.ioExecutor() != nullptr) { + unitLoaderFactory = + std::make_shared( + options_.ioExecutor(), options_.parallelUnitLoadCount()); + } else { + unitLoaderFactory = + std::make_shared( + options_.blockedOnIoCallback()); + } } return unitLoaderFactory->create(std::move(loadUnits), 0); } @@ -632,6 +640,8 @@ uint64_t DwrfRowReader::next( } else { previousRow_ = 0; } + // Collect unit loader stats at the end. 
+ unitLoadStats_ = unitLoader_->stats(); return 0; } @@ -776,15 +786,23 @@ std::optional DwrfRowReader::estimatedRowSizeHelper( } std::optional DwrfRowReader::estimatedRowSize() const { + if (hasRowEstimate_) { + return estimatedRowSize_; + } + const auto& reader = getReader(); const auto& fileFooter = reader.footer(); + hasRowEstimate_ = true; + if (!fileFooter.hasNumberOfRows()) { - return std::nullopt; + estimatedRowSize_ = std::nullopt; + return estimatedRowSize_; } if (fileFooter.numberOfRows() < 1) { - return 0; + estimatedRowSize_ = 0; + return estimatedRowSize_; } // Estimate with projections. @@ -793,9 +811,12 @@ std::optional DwrfRowReader::estimatedRowSize() const { const auto projectedSize = estimatedRowSizeHelper(fileFooter, *stats, ROOT_NODE_ID); if (projectedSize.has_value()) { - return projectedSize.value() / fileFooter.numberOfRows(); + estimatedRowSize_ = projectedSize.value() / fileFooter.numberOfRows(); + return estimatedRowSize_; } - return std::nullopt; + + estimatedRowSize_ = std::nullopt; + return estimatedRowSize_; } DwrfReader::DwrfReader( diff --git a/velox/dwio/dwrf/reader/DwrfReader.h b/velox/dwio/dwrf/reader/DwrfReader.h index dcb38dbb5a1d..ce4b8dd6b6b7 100644 --- a/velox/dwio/dwrf/reader/DwrfReader.h +++ b/velox/dwio/dwrf/reader/DwrfReader.h @@ -113,6 +113,7 @@ class DwrfRowReader : public StrideIndexProvider, stats.numStripes += stripeCeiling_ - firstStripe_; stats.columnReaderStatistics.flattenStringDictionaryValues += columnReaderStatistics_.flattenStringDictionaryValues; + stats.unitLoaderStats.merge(unitLoadStats_); } void resetFilterCaches() override; @@ -210,6 +211,8 @@ class DwrfRowReader : public StrideIndexProvider, // Number of processed strides. int64_t processedStrides_{0}; + dwio::common::UnitLoaderStats unitLoadStats_; + // Set to true after clearing filter caches, i.e. adding a dynamic filter. // Causes filters to be re-evaluated against stride stats on next stride // instead of next stripe. 
@@ -221,6 +224,9 @@ class DwrfRowReader : public StrideIndexProvider, std::unique_ptr unitLoader_; DwrfUnit* currentUnit_; + + mutable std::optional estimatedRowSize_; + mutable bool hasRowEstimate_{false}; }; class DwrfReader : public dwio::common::Reader { diff --git a/velox/dwio/dwrf/test/TestDecompression.cpp b/velox/dwio/dwrf/test/TestDecompression.cpp index 3dcf5f68e46b..429b8e015286 100644 --- a/velox/dwio/dwrf/test/TestDecompression.cpp +++ b/velox/dwio/dwrf/test/TestDecompression.cpp @@ -552,6 +552,25 @@ TEST_F(DecompressionTest, testInflate) { } } +TEST_F(DecompressionTest, testSmallBufferInflate) { + const unsigned char buffer[] = { + 0xe, 0x0, 0x0, 0x63, 0x60, 0x64, 0x62, 0xc0, 0x8d, 0x0}; + const std::unique_ptr result = createTestDecompressor( + CompressionKind_ZLIB, + std::make_unique(buffer, std::size(buffer)), + 1 // blockSize 1 to test multiple inflate calls during decompression. + ); + const void* ptr; + int32_t length; + ASSERT_EQ(true, result->Next(&ptr, &length)); + ASSERT_EQ(30, length); + for (int32_t i = 0; i < 10; ++i) { + for (int32_t j = 0; j < 3; ++j) { + EXPECT_EQ(j, static_cast(ptr)[i * 3 + j]); + } + } +} + TEST_F(DecompressionTest, testInflateSequence) { const unsigned char buffer[] = {0xe, 0x0, 0x0, 0x63, 0x60, 0x64, 0x62, 0xc0, 0x8d, 0x0, 0xe, 0x0, 0x0, 0x63, diff --git a/velox/dwio/dwrf/test/TestReadFile.h b/velox/dwio/dwrf/test/TestReadFile.h index 8501b231ed84..e47d83224182 100644 --- a/velox/dwio/dwrf/test/TestReadFile.h +++ b/velox/dwio/dwrf/test/TestReadFile.h @@ -42,15 +42,15 @@ class TestReadFile : public velox::ReadFile { uint64_t offset, uint64_t length, void* buffer, - filesystems::File::IoStats* stats = nullptr) const override { + const FileStorageContext& fileStorageContext = {}) const override { const uint64_t content = offset + seed_; const uint64_t available = std::min(length_ - offset, length); int fill; for (fill = 0; fill < available; ++fill) { reinterpret_cast(buffer)[fill] = content + fill; } - if (stats) { - 
stats->addCounter( + if (fileStorageContext.ioStats) { + fileStorageContext.ioStats->addCounter( "read", RuntimeCounter(fill, RuntimeCounter::Unit::kBytes)); } return std::string_view(static_cast(buffer), fill); @@ -59,10 +59,10 @@ class TestReadFile : public velox::ReadFile { uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override { - auto res = ReadFile::preadv(offset, buffers, stats); - if (stats) { - stats->addCounter( + const FileStorageContext& fileStorageContext = {}) const override { + auto res = ReadFile::preadv(offset, buffers, fileStorageContext); + if (fileStorageContext.ioStats) { + fileStorageContext.ioStats->addCounter( "read", RuntimeCounter( static_cast(res), RuntimeCounter::Unit::kBytes)); diff --git a/velox/dwio/dwrf/test/TestStripeStream.cpp b/velox/dwio/dwrf/test/TestStripeStream.cpp index 2e9132aeb5fb..57cb796bc26d 100644 --- a/velox/dwio/dwrf/test/TestStripeStream.cpp +++ b/velox/dwio/dwrf/test/TestStripeStream.cpp @@ -42,8 +42,8 @@ class RecordingInputStream : public facebook::velox::InMemoryReadFile { uint64_t offset, uint64_t length, void* buf, - facebook::velox::filesystems::File::IoStats* stats = - nullptr) const override { + const facebook::velox::FileStorageContext& fileStorageContext = {}) + const override { reads_.push_back({offset, length}); return {static_cast(buf), length}; } diff --git a/velox/dwio/dwrf/writer/StatisticsBuilder.h b/velox/dwio/dwrf/writer/StatisticsBuilder.h index 3441a5ddf9a1..6311acf4de00 100644 --- a/velox/dwio/dwrf/writer/StatisticsBuilder.h +++ b/velox/dwio/dwrf/writer/StatisticsBuilder.h @@ -198,13 +198,13 @@ class StatisticsBuilder : public virtual dwio::common::ColumnStatistics { rawSize_ = 0; size_ = options_.initialSize; if (options_.countDistincts) { - hll_ = std::make_shared(options_.allocator); + hll_ = std::make_shared>(options_.allocator); } } protected: StatisticsBuilderOptions options_; - std::shared_ptr hll_; + std::shared_ptr> 
hll_; }; class BooleanStatisticsBuilder : public StatisticsBuilder, diff --git a/velox/dwio/parquet/reader/PageReader.cpp b/velox/dwio/parquet/reader/PageReader.cpp index 879a97a51f3d..077ee3791aa2 100644 --- a/velox/dwio/parquet/reader/PageReader.cpp +++ b/velox/dwio/parquet/reader/PageReader.cpp @@ -17,11 +17,11 @@ #include "velox/dwio/parquet/reader/PageReader.h" #include "velox/common/testutil/TestValue.h" +#include "velox/common/time/Timer.h" #include "velox/dwio/common/BufferUtil.h" #include "velox/dwio/common/ColumnVisitors.h" #include "velox/dwio/parquet/common/LevelConversion.h" #include "velox/dwio/parquet/thrift/ThriftTransport.h" - #include "velox/vector/FlatVector.h" #include // @manual @@ -87,7 +87,12 @@ PageHeader PageReader::readPageHeader() { if (bufferEnd_ == bufferStart_) { const void* buffer; int32_t size; - inputStream_->Next(&buffer, &size); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + inputStream_->Next(&buffer, &size); + } + stats_.pageLoadTimeNs += readUs * 1'000; bufferStart_ = reinterpret_cast(buffer); bufferEnd_ = bufferStart_ + size; } @@ -106,26 +111,31 @@ PageHeader PageReader::readPageHeader() { } const char* PageReader::readBytes(int32_t size, BufferPtr& copy) { - if (bufferEnd_ == bufferStart_) { - const void* buffer = nullptr; - int32_t bufferSize = 0; - if (!inputStream_->Next(&buffer, &bufferSize)) { - VELOX_FAIL("Read past end"); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + if (bufferEnd_ == bufferStart_) { + const void* buffer = nullptr; + int32_t bufferSize = 0; + if (!inputStream_->Next(&buffer, &bufferSize)) { + VELOX_FAIL("Read past end"); + } + bufferStart_ = reinterpret_cast(buffer); + bufferEnd_ = bufferStart_ + bufferSize; } - bufferStart_ = reinterpret_cast(buffer); - bufferEnd_ = bufferStart_ + bufferSize; - } - if (bufferEnd_ - bufferStart_ >= size) { - bufferStart_ += size; - return bufferStart_ - size; - } - dwio::common::ensureCapacity(copy, size, &pool_); - 
dwio::common::readBytes( - size, - inputStream_.get(), - copy->asMutable(), - bufferStart_, - bufferEnd_); + if (bufferEnd_ - bufferStart_ >= size) { + bufferStart_ += size; + return bufferStart_ - size; + } + dwio::common::ensureCapacity(copy, size, &pool_); + dwio::common::readBytes( + size, + inputStream_.get(), + copy->asMutable(), + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs += readUs * 1'000; return copy->as(); } @@ -368,12 +378,17 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(dictionary_.values->asMutable(), pageData_, numBytes); } else { - dwio::common::readBytes( - numBytes, - inputStream_.get(), - dictionary_.values->asMutable(), - bufferStart_, - bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numBytes, + inputStream_.get(), + dictionary_.values->asMutable(), + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs += readUs * 1'000; } if (type_->type()->isShortDecimal() && parquetType == thrift::Type::INT32) { @@ -403,12 +418,17 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(dictionary_.values->asMutable(), pageData_, numBytes); } else { - dwio::common::readBytes( - numBytes, - inputStream_.get(), - dictionary_.values->asMutable(), - bufferStart_, - bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numBytes, + inputStream_.get(), + dictionary_.values->asMutable(), + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs += readUs * 1'000; } // Expand the Parquet type length values to Velox type length. // We start from the end to allow in-place expansion. 
@@ -435,8 +455,13 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(strings, pageData_, numBytes); } else { - dwio::common::readBytes( - numBytes, inputStream_.get(), strings, bufferStart_, bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numBytes, inputStream_.get(), strings, bufferStart_, bufferEnd_); + } + stats_.pageLoadTimeNs += readUs * 1'000; } auto header = strings; for (auto i = 0; i < dictionary_.numValues; ++i) { @@ -458,12 +483,17 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(data, pageData_, numParquetBytes); } else { - dwio::common::readBytes( - numParquetBytes, - inputStream_.get(), - data, - bufferStart_, - bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numParquetBytes, + inputStream_.get(), + data, + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs += readUs * 1'000; } if (type_->type()->isShortDecimal()) { // Parquet decimal values have a fixed typeLength_ and are in big-endian diff --git a/velox/dwio/parquet/reader/PageReader.h b/velox/dwio/parquet/reader/PageReader.h index c377100428a7..c0d34fa968cd 100644 --- a/velox/dwio/parquet/reader/PageReader.h +++ b/velox/dwio/parquet/reader/PageReader.h @@ -42,6 +42,7 @@ class PageReader { ParquetTypeWithIdPtr fileType, common::CompressionKind codec, int64_t chunkSize, + dwio::common::ColumnReaderStatistics& stats, const tz::TimeZone* sessionTimezone) : pool_(pool), inputStream_(std::move(stream)), @@ -52,6 +53,7 @@ class PageReader { codec_(codec), chunkSize_(chunkSize), nullConcatenation_(pool_), + stats_(stats), sessionTimezone_(sessionTimezone) { type_->makeLevelInfo(leafInfo_); } @@ -62,6 +64,7 @@ class PageReader { memory::MemoryPool& pool, common::CompressionKind codec, int64_t chunkSize, + dwio::common::ColumnReaderStatistics& stats, const tz::TimeZone* sessionTimezone = nullptr) : 
pool_(pool), inputStream_(std::move(stream)), @@ -71,6 +74,7 @@ class PageReader { codec_(codec), chunkSize_(chunkSize), nullConcatenation_(pool_), + stats_(stats), sessionTimezone_(sessionTimezone) {} /// Advances 'numRows' top level rows. @@ -502,6 +506,8 @@ class PageReader { // Base values of dictionary when reading a string dictionary. VectorPtr dictionaryValues_; + dwio::common::ColumnReaderStatistics& stats_; + const tz::TimeZone* sessionTimezone_{nullptr}; // Decoders. Only one will be set at a time. diff --git a/velox/dwio/parquet/reader/ParquetColumnReader.cpp b/velox/dwio/parquet/reader/ParquetColumnReader.cpp index 0b69f446280d..e98159b7a46e 100644 --- a/velox/dwio/parquet/reader/ParquetColumnReader.cpp +++ b/velox/dwio/parquet/reader/ParquetColumnReader.cpp @@ -65,9 +65,11 @@ std::unique_ptr ParquetColumnReader::build( case TypeKind::VARCHAR: return std::make_unique(fileType, params, scanSpec); - case TypeKind::ARRAY: + case TypeKind::ARRAY: { + VELOX_CHECK(requestedType->isArray(), "Requested type must be array"); return std::make_unique( columnReaderOptions, requestedType, fileType, params, scanSpec); + } case TypeKind::MAP: return std::make_unique( diff --git a/velox/dwio/parquet/reader/ParquetData.cpp b/velox/dwio/parquet/reader/ParquetData.cpp index 788d04e39621..572c53acc2db 100644 --- a/velox/dwio/parquet/reader/ParquetData.cpp +++ b/velox/dwio/parquet/reader/ParquetData.cpp @@ -25,7 +25,7 @@ std::unique_ptr ParquetParams::toFormatData( const std::shared_ptr& type, const common::ScanSpec& /*scanSpec*/) { return std::make_unique( - type, metaData_, pool(), sessionTimezone_); + type, metaData_, pool(), runtimeStatistics(), sessionTimezone_); } void ParquetData::filterRowGroups( @@ -128,6 +128,7 @@ dwio::common::PositionProvider ParquetData::seekToRowGroup(int64_t index) { type_, metadata.compression(), metadata.totalCompressedSize(), + stats_, sessionTimezone_); return dwio::common::PositionProvider(empty); } diff --git 
a/velox/dwio/parquet/reader/ParquetData.h b/velox/dwio/parquet/reader/ParquetData.h index 1ea4a1e8c774..9926202491d6 100644 --- a/velox/dwio/parquet/reader/ParquetData.h +++ b/velox/dwio/parquet/reader/ParquetData.h @@ -63,6 +63,7 @@ class ParquetData : public dwio::common::FormatData { const std::shared_ptr& type, const FileMetaDataPtr fileMetadataPtr, memory::MemoryPool& pool, + dwio::common::ColumnReaderStatistics& stats, const tz::TimeZone* sessionTimezone) : pool_(pool), type_(std::static_pointer_cast(type)), @@ -70,6 +71,7 @@ class ParquetData : public dwio::common::FormatData { maxDefine_(type_->maxDefine_), maxRepeat_(type_->maxRepeat_), rowsInRowGroup_(-1), + stats_(stats), sessionTimezone_(sessionTimezone) {} /// Prepares to read data for 'index'th row group. @@ -90,8 +92,9 @@ class ParquetData : public dwio::common::FormatData { return reader_.get(); } - // Reads null flags for 'numValues' next top level rows. The first 'numValues' - // bits of 'nulls' are set and the reader is advanced by numValues'. + // Reads null flags for 'numValues' next top level rows. The first + // 'numValues' bits of 'nulls' are set and the reader is advanced by + // numValues'. void readNullsOnly(int32_t numValues, BufferPtr& nulls) { reader_->readNullsOnly(numValues, nulls); } @@ -100,8 +103,9 @@ class ParquetData : public dwio::common::FormatData { return maxDefine_ > 0; } - /// Sets nulls to be returned by readNulls(). Nulls for non-leaf readers come - /// from leaf repdefs which are gathered before descending the reader tree. + /// Sets nulls to be returned by readNulls(). Nulls for non-leaf readers + /// come from leaf repdefs which are gathered before descending the reader + /// tree. 
void setNulls(BufferPtr& nulls, int32_t numValues) { if (nulls || numValues) { VELOX_CHECK_EQ(presetNullsConsumed_, presetNullsSize_); @@ -120,8 +124,8 @@ class ParquetData : public dwio::common::FormatData { const uint64_t* incomingNulls, BufferPtr& nulls, bool nullsOnly = false) override { - // If the query accesses only nulls, read the nulls from the pages in range. - // If nulls are preread, return those minus any skipped. + // If the query accesses only nulls, read the nulls from the pages in + // range. If nulls are preread, return those minus any skipped. if (presetNulls_) { VELOX_CHECK_LE(numValues, presetNullsSize_ - presetNullsConsumed_); if (!presetNullsConsumed_ && numValues == presetNullsSize_) { @@ -144,8 +148,8 @@ class ParquetData : public dwio::common::FormatData { readNullsOnly(numValues, nulls); return; } - // There are no column-level nulls in Parquet, only page-level ones, so this - // is always non-null. + // There are no column-level nulls in Parquet, only page-level ones, so + // this is always non-null. nulls = nullptr; } @@ -219,6 +223,7 @@ class ParquetData : public dwio::common::FormatData { const uint32_t maxDefine_; const uint32_t maxRepeat_; int64_t rowsInRowGroup_; + dwio::common::ColumnReaderStatistics& stats_; const tz::TimeZone* sessionTimezone_; std::unique_ptr reader_; diff --git a/velox/dwio/parquet/reader/ParquetReader.cpp b/velox/dwio/parquet/reader/ParquetReader.cpp index 955abc91b8ad..dd44355e2fdd 100644 --- a/velox/dwio/parquet/reader/ParquetReader.cpp +++ b/velox/dwio/parquet/reader/ParquetReader.cpp @@ -37,6 +37,21 @@ bool isParquetReservedKeyword( ? true : false; } + +// An unannotated array in Parquet is a repeated field that is not explicitly +// marked as a LIST logical type. If current schema element is a repeated field +// and the requested type is an array, we treat the current schema element as an +// unannotated array, and returns true if the element type is compatible with +// the physical type. 
+bool isCompatible( + const TypePtr& requestedType, + bool isRepeated, + const std::function& isCompatibleFunc) { + return isCompatibleFunc(requestedType) || + (requestedType->isArray() && isRepeated && + isCompatibleFunc(requestedType->asArray().elementType())); +} + } // namespace /// Metadata and options for reading Parquet. @@ -337,21 +352,37 @@ std::unique_ptr ReaderBase::getParquetColumnInfo( TypePtr childRequestedType = nullptr; bool followChild = true; - if (requestedType && requestedType->isRow()) { - auto requestedRowType = - std::dynamic_pointer_cast(requestedType); - if (options_.useColumnNamesForColumnMapping()) { - auto fileTypeIdx = requestedRowType->getChildIdxIfExists(childName); - if (fileTypeIdx.has_value()) { - childRequestedType = requestedRowType->childAt(*fileTypeIdx); + + { + RowTypePtr requestedRowType = nullptr; + if (requestedType) { + if (requestedType->isRow()) { + requestedRowType = + std::dynamic_pointer_cast(requestedType); + } else if ( + requestedType->isArray() && isRepeated && + requestedType->asArray().elementType()->isRow()) { + // Handle the case of unannotated array of structs (repeated group + // without LIST annotation). + requestedRowType = std::dynamic_pointer_cast( + requestedType->asArray().elementType()); } - } else { - // Handle schema evolution. - if (i < requestedRowType->size()) { - columnNames.push_back(requestedRowType->nameOf(i)); - childRequestedType = requestedRowType->childAt(i); + } + + if (requestedRowType) { + if (options_.useColumnNamesForColumnMapping()) { + auto fileTypeIdx = requestedRowType->getChildIdxIfExists(childName); + if (fileTypeIdx.has_value()) { + childRequestedType = requestedRowType->childAt(*fileTypeIdx); + } } else { - followChild = false; + // Handle schema evolution. 
+ if (i < requestedRowType->size()) { + columnNames.push_back(requestedRowType->nameOf(i)); + childRequestedType = requestedRowType->childAt(i); + } else { + followChild = false; + } } } } @@ -722,6 +753,8 @@ TypePtr ReaderBase::convertType( static std::string_view kTypeMappingErrorFmtStr = "Converted type {} is not allowed for requested type {}"; + const bool isRepeated = schemaElement.__isset.repetition_type && + schemaElement.repetition_type == thrift::FieldRepetitionType::REPEATED; if (schemaElement.__isset.converted_type) { switch (schemaElement.converted_type) { case thrift::ConvertedType::INT_8: @@ -732,10 +765,16 @@ TypePtr ReaderBase::convertType( "{} converted type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TINYINT || - requestedType->kind() == TypeKind::SMALLINT || - requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::TINYINT || + type->kind() == TypeKind::SMALLINT || + type->kind() == TypeKind::INTEGER || + type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "TINYINT", requestedType->toString()); @@ -749,9 +788,15 @@ TypePtr ReaderBase::convertType( "{} converted type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::SMALLINT || - requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::SMALLINT || + type->kind() == TypeKind::INTEGER || + type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "SMALLINT", requestedType->toString()); @@ -765,8 +810,14 @@ TypePtr ReaderBase::convertType( "{} converted 
type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::INTEGER || + type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "INTEGER", requestedType->toString()); @@ -780,7 +831,13 @@ TypePtr ReaderBase::convertType( "{} converted type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "BIGINT", requestedType->toString()); @@ -792,7 +849,11 @@ TypePtr ReaderBase::convertType( thrift::Type::INT32, "DATE converted type can only be set for value of thrift::Type::INT32"); VELOX_CHECK( - !requestedType || requestedType->isDate(), + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { return type->isDate(); }), kTypeMappingErrorFmtStr, "DATE", requestedType->toString()); @@ -805,7 +866,13 @@ TypePtr ReaderBase::convertType( thrift::Type::INT64, "TIMESTAMP_MICROS or TIMESTAMP_MILLIS converted type can only be set for value of thrift::Type::INT64"); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TIMESTAMP, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::TIMESTAMP; + }), kTypeMappingErrorFmtStr, "TIMESTAMP", requestedType->toString()); @@ -823,7 +890,10 @@ TypePtr ReaderBase::convertType( auto type = DECIMAL(schemaElementPrecision, schemaElementScale); if (requestedType) { VELOX_CHECK( - requestedType->isDecimal(), + isCompatible( + requestedType, + isRepeated, + [](const 
TypePtr& type) { return type->isDecimal(); }), kTypeMappingErrorFmtStr, "DECIMAL", requestedType->toString()); @@ -832,20 +902,30 @@ TypePtr ReaderBase::convertType( // the scale of the file type and requested type must match while // precision may be larger. if (requestedType->isShortDecimal()) { - const auto& shortDecimalType = requestedType->asShortDecimal(); VELOX_CHECK( - type->isShortDecimal() && - shortDecimalType.precision() >= schemaElementPrecision && - shortDecimalType.scale() == schemaElementScale, + isCompatible( + requestedType, + isRepeated, + [&](const TypePtr& type) { + return type->isShortDecimal() && + type->asShortDecimal().precision() >= + schemaElementPrecision && + type->asShortDecimal().scale() == schemaElementScale; + }), kTypeMappingErrorFmtStr, type->toString(), requestedType->toString()); } else { - const auto& longDecimalType = requestedType->asLongDecimal(); VELOX_CHECK( - type->isLongDecimal() && - longDecimalType.precision() >= schemaElementPrecision && - longDecimalType.scale() == schemaElementScale, + isCompatible( + requestedType, + isRepeated, + [&](const TypePtr& type) { + return type->isLongDecimal() && + type->asLongDecimal().precision() >= + schemaElementPrecision && + type->asLongDecimal().scale() == schemaElementScale; + }), kTypeMappingErrorFmtStr, type->toString(), requestedType->toString()); @@ -859,7 +939,13 @@ TypePtr ReaderBase::convertType( case thrift::Type::BYTE_ARRAY: case thrift::Type::FIXED_LEN_BYTE_ARRAY: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::VARCHAR, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::VARCHAR; + }), kTypeMappingErrorFmtStr, "VARCHAR", requestedType->toString()); @@ -874,7 +960,13 @@ TypePtr ReaderBase::convertType( thrift::Type::BYTE_ARRAY, "ENUM converted type can only be set for value of thrift::Type::BYTE_ARRAY"); VELOX_CHECK( - !requestedType || requestedType->kind() == 
TypeKind::VARCHAR, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::VARCHAR; + }), kTypeMappingErrorFmtStr, "VARCHAR", requestedType->toString()); @@ -897,15 +989,27 @@ TypePtr ReaderBase::convertType( switch (schemaElement.type) { case thrift::Type::type::BOOLEAN: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::BOOLEAN, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::BOOLEAN; + }), kTypeMappingErrorFmtStr, "BOOLEAN", requestedType->toString()); return BOOLEAN(); case thrift::Type::type::INT32: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::INTEGER || + type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "INTEGER", requestedType->toString()); @@ -915,47 +1019,84 @@ TypePtr ReaderBase::convertType( if (schemaElement.__isset.logicalType && schemaElement.logicalType.__isset.TIMESTAMP) { VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TIMESTAMP, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::TIMESTAMP; + }), kTypeMappingErrorFmtStr, "TIMESTAMP", requestedType->toString()); return TIMESTAMP(); } VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "BIGINT", requestedType->toString()); return BIGINT(); case thrift::Type::type::INT96: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TIMESTAMP, + !requestedType || + isCompatible( + requestedType, + isRepeated, + 
[](const TypePtr& type) { + return type->kind() == TypeKind::TIMESTAMP; + }), kTypeMappingErrorFmtStr, "TIMESTAMP", requestedType->toString()); return TIMESTAMP(); // INT96 only maps to a timestamp case thrift::Type::type::FLOAT: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::REAL || - requestedType->kind() == TypeKind::DOUBLE, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::REAL || + type->kind() == TypeKind::DOUBLE; + }), kTypeMappingErrorFmtStr, "REAL", requestedType->toString()); return REAL(); case thrift::Type::type::DOUBLE: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::DOUBLE, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::DOUBLE; + }), kTypeMappingErrorFmtStr, "DOUBLE", requestedType->toString()); return DOUBLE(); case thrift::Type::type::BYTE_ARRAY: case thrift::Type::type::FIXED_LEN_BYTE_ARRAY: - if (requestedType && requestedType->isVarchar()) { + if (requestedType && + isCompatible(requestedType, isRepeated, [](const TypePtr& type) { + return type->isVarchar(); + })) { return VARCHAR(); } else { VELOX_CHECK( - !requestedType || requestedType->isVarbinary(), + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { return type->isVarbinary(); }), kTypeMappingErrorFmtStr, "VARBINARY", requestedType->toString()); @@ -1178,14 +1319,21 @@ class ParquetRowReader::Impl { std::optional estimatedRowSize() const { auto index = nextRowGroupIdsIdx_ < 1 ? 
0 : rowGroupIds_[nextRowGroupIdsIdx_ - 1]; - return readerBase_->rowGroupUncompressedSize( - index, *readerBase_->schemaWithId()) / + if (index == lastRowGroupWithRowEstimate_) { + return estimatedRowSize_; + } + estimatedRowSize_ = readerBase_->rowGroupUncompressedSize( + index, *readerBase_->schemaWithId()) / rowGroups_[index].num_rows; + lastRowGroupWithRowEstimate_ = index; + return estimatedRowSize_; } void updateRuntimeStats(dwio::common::RuntimeStatistics& stats) const { stats.skippedStrides += skippedStrides_; stats.processedStrides += rowGroupIds_.size(); + stats.columnReaderStatistics.pageLoadTimeNs += + columnReaderStats_.pageLoadTimeNs; } void resetFilterCaches() { @@ -1237,6 +1385,9 @@ class ParquetRowReader::Impl { ParquetStatsContext parquetStatsContext_; dwio::common::ColumnReaderStatistics columnReaderStats_; + + mutable std::optional estimatedRowSize_; + mutable int32_t lastRowGroupWithRowEstimate_{-1}; }; ParquetRowReader::ParquetRowReader( diff --git a/velox/dwio/parquet/tests/examples/nested_array_struct.parquet b/velox/dwio/parquet/tests/examples/nested_array_struct.parquet new file mode 100644 index 000000000000..41a43fa35d39 Binary files /dev/null and b/velox/dwio/parquet/tests/examples/nested_array_struct.parquet differ diff --git a/velox/dwio/parquet/tests/examples/proto_repeated_string.parquet b/velox/dwio/parquet/tests/examples/proto_repeated_string.parquet new file mode 100644 index 000000000000..8a7eea601d01 Binary files /dev/null and b/velox/dwio/parquet/tests/examples/proto_repeated_string.parquet differ diff --git a/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp b/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp index 5145dcfdc8ca..87c763d78994 100644 --- a/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp +++ b/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp @@ -31,11 +31,13 @@ TEST_F(ParquetPageReaderTest, smallPage) { auto headerSize = file->getLength(); auto inputStream = std::make_unique( 
std::move(file), 0, headerSize, *leafPool_, LogType::TEST); + dwio::common::ColumnReaderStatistics stats; auto pageReader = std::make_unique( std::move(inputStream), *leafPool_, common::CompressionKind::CompressionKind_GZIP, - headerSize); + headerSize, + stats); auto header = pageReader->readPageHeader(); EXPECT_EQ(header.type, thrift::PageType::type::DATA_PAGE); EXPECT_EQ(header.uncompressed_page_size, 16950); @@ -50,6 +52,7 @@ TEST_F(ParquetPageReaderTest, smallPage) { auto maxValue = header.data_page_header.statistics.max_value; EXPECT_EQ(minValue, expectedMinValue); EXPECT_EQ(maxValue, expectedMaxValue); + EXPECT_GT(stats.pageLoadTimeNs, 0); } TEST_F(ParquetPageReaderTest, largePage) { @@ -59,11 +62,13 @@ TEST_F(ParquetPageReaderTest, largePage) { auto headerSize = file->getLength(); auto inputStream = std::make_unique( std::move(file), 0, headerSize, *leafPool_, LogType::TEST); + dwio::common::ColumnReaderStatistics stats; auto pageReader = std::make_unique( std::move(inputStream), *leafPool_, common::CompressionKind::CompressionKind_GZIP, - headerSize); + headerSize, + stats); auto header = pageReader->readPageHeader(); EXPECT_EQ(header.type, thrift::PageType::type::DATA_PAGE); @@ -79,6 +84,7 @@ TEST_F(ParquetPageReaderTest, largePage) { auto maxValue = header.data_page_header.statistics.max_value; EXPECT_EQ(minValue, expectedMinValue); EXPECT_EQ(maxValue, expectedMaxValue); + EXPECT_GT(stats.pageLoadTimeNs, 0); } TEST_F(ParquetPageReaderTest, corruptedPageHeader) { @@ -92,11 +98,13 @@ TEST_F(ParquetPageReaderTest, corruptedPageHeader) { // In the corrupted_page_header, the min_value length is set incorrectly on // purpose. This is to simulate the situation where the Parquet Page Header is // corrupted. And an error is expected to be thrown. 
+ dwio::common::ColumnReaderStatistics stats; auto pageReader = std::make_unique( std::move(inputStream), *leafPool_, common::CompressionKind::CompressionKind_GZIP, - headerSize); + headerSize, + stats); EXPECT_THROW(pageReader->readPageHeader(), VeloxException); } diff --git a/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp b/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp index 96c4a3a40513..5cef59993eaf 100644 --- a/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp +++ b/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp @@ -539,8 +539,77 @@ TEST_F(ParquetTableScanTest, array) { vector, })); - assertSelectWithFilter( - {"repeatedInt"}, {}, "", "SELECT UNNEST(array[array[1,2,3]])"); + assertSelectWithFilter({"repeatedInt"}, {}, "", "SELECT [1,2,3]"); + + // Set the requested type for unannotated array. + auto rowType = ROW({"repeatedInt"}, {ARRAY(INTEGER())}); + auto plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits({makeSplit(getExampleFilePath("old_repeated_int.parquet"))}) + .assertResults("SELECT [1,2,3]"); + + // Throws when reading repeated values as scalar type. 
+ rowType = ROW({"repeatedInt"}, {INTEGER()}); + plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits({makeSplit(getExampleFilePath("old_repeated_int.parquet"))}) + .assertResults(""), + "Requested type must be array"); + + rowType = ROW({"mystring"}, {ARRAY(VARCHAR())}); + plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits({makeSplit(getExampleFilePath("proto_repeated_string.parquet"))}) + .assertResults( + "SELECT UNNEST(array[array['hello', 'world'], array['good','bye'], array['one', 'two', 'three']])"); + + rowType = + ROW({"primitive", "myComplex"}, + {INTEGER(), + ARRAY( + ROW({"id", "repeatedMessage"}, + {INTEGER(), ARRAY(ROW({"someId"}, {INTEGER()}))}))}); + plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + + // Construct the expected vector. + auto someIdVector = makeArrayOfRowVector( + ROW({"someId"}, {INTEGER()}), + { + {variant::row({3})}, + {variant::row({6})}, + {variant::row({9})}, + }); + auto rowVector = makeRowVector( + {"id", "repeatedMessage"}, + { + makeFlatVector({1, 4, 7}), + someIdVector, + }); + auto expected = makeRowVector( + {"primitive", "myComplex"}, + { + makeFlatVector({2, 5, 8}), + makeArrayVector({0, 1, 2}, rowVector), + }); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .connectorSessionProperty( + kHiveConnectorId, + connector::hive::HiveConfig::kParquetUseColumnNamesSession, + "true") + .splits({makeSplit(getExampleFilePath("nested_array_struct.parquet"))}) + .assertResults(expected); } // Optional array with required elements. 
diff --git a/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp b/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp index d22cbba6389b..61e260ba8d72 100644 --- a/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp +++ b/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp @@ -77,6 +77,7 @@ class ParquetWriterTest : public ParquetTestBase { }; inline static const std::string kHiveConnectorId = "test-hive"; + dwio::common::ColumnReaderStatistics stats; }; class ArrowMemoryPool final : public ::arrow::MemoryPool { @@ -199,7 +200,8 @@ TEST_F(ParquetWriterTest, dictionaryEncodingWithDictionaryPageSize) { std::move(inputStream), *leafPool_, colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); + colChunkPtr.totalCompressedSize(), + stats); return pageReader->readPageHeader(); } constexpr int64_t kFirstDataPageCompressedSize = 1291; @@ -215,7 +217,8 @@ TEST_F(ParquetWriterTest, dictionaryEncodingWithDictionaryPageSize) { std::move(inputStream), *leafPool_, colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); + colChunkPtr.totalCompressedSize(), + stats); return pageReader->readPageHeader(); }; @@ -367,7 +370,8 @@ TEST_F(ParquetWriterTest, dictionaryEncodingOff) { std::move(inputStream), *leafPool_, colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); + colChunkPtr.totalCompressedSize(), + stats); return pageReader->readPageHeader(); }; @@ -534,7 +538,8 @@ TEST_F(ParquetWriterTest, testPageSizeAndBatchSizeConfiguration) { std::move(inputStream), *leafPool_, colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); + colChunkPtr.totalCompressedSize(), + stats); return pageReader->readPageHeader(); }; @@ -681,7 +686,8 @@ TEST_F(ParquetWriterTest, toggleDataPageVersion) { std::move(inputStream), *leafPool_, colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); + colChunkPtr.totalCompressedSize(), + stats); return pageReader->readPageHeader().type; }; @@ -822,7 +828,7 @@ TEST_F(ParquetWriterTest, 
parquetWriteWithArrowMemoryPool) { TEST_F(ParquetWriterTest, updateWriterOptionsFromHiveConfig) { std::unordered_map configFromFile = { - {parquet::WriterOptions::kParquetSessionWriteTimestampUnit, "3"}}; + {parquet::WriterOptions::kParquetHiveConnectorWriteTimestampUnit, "3"}}; const config::ConfigBase connectorConfig(std::move(configFromFile)); const config::ConfigBase connectorSessionProperties({}); @@ -945,6 +951,44 @@ TEST_F(ParquetWriterTest, dictionaryEncodedVector) { EXPECT_FALSE(wrappedVector->wrappedVector()->isFlatEncoding()); writeToFile(makeRowVector({wrappedVector})); +} + +TEST_F(ParquetWriterTest, allNulls) { + auto schema = ROW({"c0"}, {INTEGER()}); + const int64_t kRows = 4096; + // Create a column with all elements being null. + auto nulls = makeNulls(kRows, [](auto /*row*/) { return true; }); + auto flatVector = std::make_shared>( + pool_.get(), + schema->childAt(0), + nulls, + kRows, + /*values=*/nullptr, + std::vector()); + auto data = std::make_shared( + pool_.get(), schema, nullptr, kRows, std::vector{flatVector}); + + // Create an in-memory writer. 
+ auto sink = std::make_unique( + 200 * 1024 * 1024, + dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto sinkPtr = sink.get(); + facebook::velox::parquet::WriterOptions writerOptions; + writerOptions.memoryPool = leafPool_.get(); + + auto writer = std::make_unique( + std::move(sink), writerOptions, rootPool_, schema); + writer->write(data); + writer->close(); + + dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + auto reader = createReaderInMemory(*sinkPtr, readerOptions); + + ASSERT_EQ(reader->numberOfRows(), kRows); + ASSERT_EQ(*reader->rowType(), *schema); + + auto rowReader = createRowReaderWithSchema(std::move(reader), schema); + assertReadWithReaderAndExpected(schema, *rowReader, data, *leafPool_); }; } // namespace diff --git a/velox/dwio/parquet/writer/Writer.cpp b/velox/dwio/parquet/writer/Writer.cpp index 2444fa247cc5..d2cccdf2189b 100644 --- a/velox/dwio/parquet/writer/Writer.cpp +++ b/velox/dwio/parquet/writer/Writer.cpp @@ -564,7 +564,8 @@ void WriterOptions::processConfigs( parquetWriteTimestampUnit = getTimestampUnit(session, kParquetSessionWriteTimestampUnit).has_value() ? 
getTimestampUnit(session, kParquetSessionWriteTimestampUnit) - : getTimestampUnit(connectorConfig, kParquetSessionWriteTimestampUnit); + : getTimestampUnit( + connectorConfig, kParquetHiveConnectorWriteTimestampUnit); } if (!parquetWriteTimestampTimeZone) { parquetWriteTimestampTimeZone = parquetWriterOptions->sessionTimezoneName; diff --git a/velox/dwio/text/reader/TextReader.cpp b/velox/dwio/text/reader/TextReader.cpp index a157144ea462..3d319b6faba7 100644 --- a/velox/dwio/text/reader/TextReader.cpp +++ b/velox/dwio/text/reader/TextReader.cpp @@ -39,6 +39,7 @@ constexpr const char* kTextfileCompressionExtensionZst = ".zst"; static std::string emptyString = std::string(); namespace { +constexpr const int32_t kDecompressionBufferFactor = 3; void resizeVector( BaseVector* FOLLY_NULLABLE data, @@ -191,34 +192,15 @@ TextRowReader::TextRowReader( } limit_ = std::numeric_limits::max(); - /** - * The output buffer for decompression is allocated based on the - * uncompressed length of the stream. - * - * For decompressors other than ZlibDecompressor, the uncompressed length is - * obtained via getDecompressedLength, and blockSize serves only as a - * fallbak when getDecompressedLength fails to return a valid length. - * - * ZlibDecompressor does not implement getDecompressedLength because the - * DEFLATE algorithm used by zlib does not inherently includes the - * uncompressed length in the compressed stream. As a result, blockSize is - * used to set z_stream.avail_out during decompression to ensure enough - * buffer allocated for the output. Since zlib requires avail_out to be a - * uInt (unsigned int), blockSize is set to std::numeric_limits::max() for full compatibility. - */ - const auto blockSize = - (contents_->compression == CompressionKind::CompressionKind_ZLIB || - contents_->compression == CompressionKind::CompressionKind_GZIP) - ? 
std::numeric_limits::max() - : std::numeric_limits::max(); - contents_->inputStream = contents_->input->loadCompleteFile(); auto name = contents_->inputStream->getName(); contents_->decompressedInputStream = createDecompressor( contents_->compression, std::move(contents_->inputStream), - blockSize, + // An estimated value used as the output buffer size for the zlib + // decompressor, and as the fallback value of the decompressed length + // for other decompressors. + kDecompressionBufferFactor * contents_->fileLength, contents_->pool, contents_->compressionOptions, fmt::format("Text Reader: Stream {}", name), @@ -559,6 +541,24 @@ TextRowReader::getString(TextRowReader& th, bool& isNull, DelimType& delim) { return th.ownedString_; } +template +void TextRowReader::setValueFromString( + const std::string& str, + BaseVector* data, + vector_size_t insertionRow, + std::function(const std::string&)> convert) { + if ((atEOF_ && atSOL_) || data == nullptr) { + return; + } + auto flatVector = data->asChecked>(); + auto result = str.empty() ? 
std::nullopt : convert(str); + if (result) { + flatVector->set(insertionRow, *result); + } else { + flatVector->setNull(insertionRow, true); + } +} + uint8_t TextRowReader::getByte(DelimType& delim) { setNone(delim); auto v = getByteUnchecked(delim); @@ -1052,8 +1052,19 @@ void TextRowReader::readElement( getInteger, data, insertionRow, delim); break; case TypeKind::INTEGER: - putValue( - getInteger, data, insertionRow, delim); + if (reqT->isDate()) { + const std::string& str = getString(*this, isNull, delim); + setValueFromString( + str, + data, + insertionRow, + [](const std::string& s) -> std::optional { + return DATE()->toDays(s); + }); + } else { + putValue( + getInteger, data, insertionRow, delim); + } break; default: VELOX_FAIL( @@ -1065,10 +1076,61 @@ void TextRowReader::readElement( break; case TypeKind::BIGINT: - putValue( - getInteger, data, insertionRow, delim); + if (reqT->isShortDecimal()) { + const std::string& str = getString(*this, isNull, delim); + auto decimalParams = getDecimalPrecisionScale(*reqT); + const auto precision = decimalParams.first; + const auto scale = decimalParams.second; + setValueFromString( + str, + data, + insertionRow, + [precision, scale](const std::string& s) -> std::optional { + int64_t v = 0; + const auto status = DecimalUtil::castFromString( + StringView(s.data(), static_cast(s.size())), + precision, + scale, + v); + return status.ok() ? 
std::optional(v) : std::nullopt; + }); + } else { + putValue( + getInteger, data, insertionRow, delim); + } break; + case TypeKind::HUGEINT: { + const std::string& str = getString(*this, isNull, delim); + if (reqT->isLongDecimal()) { + auto decimalParams = getDecimalPrecisionScale(*reqT); + const auto precision = decimalParams.first; + const auto scale = decimalParams.second; + setValueFromString( + str, + data, + insertionRow, + [precision, + scale](const std::string& s) -> std::optional { + int128_t v = 0; + const auto status = DecimalUtil::castFromString( + StringView(s.data(), static_cast(s.size())), + precision, + scale, + v); + return status.ok() ? std::optional(v) : std::nullopt; + }); + } else { + setValueFromString( + str, + data, + insertionRow, + [](const std::string& s) -> std::optional { + return HugeInt::parse(s); + }); + } + break; + } case TypeKind::SMALLINT: switch (reqT->kind()) { case TypeKind::BIGINT: @@ -1639,17 +1701,4 @@ uint64_t TextReader::getFileLength() const { return contents_->fileLength; } -uint64_t TextReader::getMemoryUse() { - uint64_t memory = std::min( - uint64_t(contents_->fileLength), - contents_->input->getInputStream()->getNaturalReadSize()); - - // Decompressor needs a buffer. 
- if (contents_->compression != CompressionKind::CompressionKind_NONE) { - memory *= 3; - } - - return memory; -} - } // namespace facebook::velox::text diff --git a/velox/dwio/text/reader/TextReader.h b/velox/dwio/text/reader/TextReader.h index e15635087051..435de81c35a7 100644 --- a/velox/dwio/text/reader/TextReader.h +++ b/velox/dwio/text/reader/TextReader.h @@ -85,8 +85,6 @@ class TextReader : public dwio::common::Reader { uint64_t getFileLength() const; - uint64_t getMemoryUse(); - private: ReaderOptions options_; mutable std::shared_ptr typeWithId_; @@ -206,6 +204,13 @@ class TextRowReader : public dwio::common::RowReader { vector_size_t insertionRow, DelimType& delim); + template + void setValueFromString( + const std::string& str, + BaseVector* FOLLY_NULLABLE data, + vector_size_t insertionRow, + std::function(const std::string&)> convert); + const std::shared_ptr contents_; const std::shared_ptr schemaWithId_; const std::shared_ptr& scanSpec_; diff --git a/velox/dwio/text/tests/reader/TextReaderTest.cpp b/velox/dwio/text/tests/reader/TextReaderTest.cpp index b1b3bbda23e6..73669ea87b20 100644 --- a/velox/dwio/text/tests/reader/TextReaderTest.cpp +++ b/velox/dwio/text/tests/reader/TextReaderTest.cpp @@ -28,6 +28,10 @@ namespace facebook::velox::text { namespace { +int32_t parseDate(const std::string& text) { + return DATE()->toDays(text); +} + class TextReaderTest : public testing::Test, public velox::test::VectorTestBase { protected: @@ -1639,6 +1643,95 @@ TEST_F(TextReaderTest, varbinaryUnsuccessfulDecoding) { EXPECT_EQ(binaryVector->valueAt(1), StringView("Another@Invalid#String")); } +TEST_F(TextReaderTest, logicalTypes) { + auto expected = makeRowVector( + {makeNullableFlatVector( + {0, + 123, + -1234567, + 999999999999999, + std::nullopt, + 4242, + -1, + std::nullopt, + 314159265358979, + 77777, + 100000000000000, + -5432199, + std::nullopt, + 1234, + -999999999999999, + 999999999999999}, + DECIMAL(15, 2)), + makeNullableFlatVector( + {0, + 
HugeInt::parse("999999999999999999999"), + HugeInt::parse("123456789012345678901234567890"), + HugeInt::parse("-99999999999999999999999999"), + HugeInt::parse("88888888888888888888"), + std::nullopt, + 1, + std::nullopt, + HugeInt::parse("27182818284590452353612"), + HugeInt::parse("-123456789012345678999"), + HugeInt::parse("12345678901234567890123456789012345678"), + 987654321012, + std::nullopt, + 5678, + -123, + HugeInt::parse("99999999999999999999999999999999")}, + DECIMAL(38, 2)), + makeNullableFlatVector( + { + parseDate("1970-01-01"), + parseDate("2024-02-29"), + parseDate("1900-01-01"), + parseDate("2099-12-31"), + parseDate("2001-09-11"), + parseDate("2025-09-10"), + std::nullopt, + std::nullopt, + parseDate("1999-12-31"), + parseDate("2012-12-21"), + parseDate("2200-01-01"), + parseDate("1988-08-08"), + parseDate("1969-07-20"), + parseDate("2000-01-01"), + parseDate("1800-06-15"), + parseDate("2500-12-31"), + }, + DATE())}); + + auto type = + ROW({{"c0", DECIMAL(15, 2)}, {"c1", DECIMAL(38, 2)}, {"c2", DATE()}}); + + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/logical_types.gz"); + + auto readFile = std::make_shared(path); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + ASSERT_EQ(rowReader->next(10, result), 10); + for (int i = 0; i < 10; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 6); + for (int i = 0; i < 6; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 10 + i)); 
+ } +} + TEST_F(TextReaderTest, nestedRows) { auto nestedRowChildren = std::vector{ makeFlatVector({42, 100, -5, 0, 999}), diff --git a/velox/dwio/text/tests/reader/examples/logical_types.gz b/velox/dwio/text/tests/reader/examples/logical_types.gz new file mode 100644 index 000000000000..3a20dc8ceeef Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/logical_types.gz differ diff --git a/velox/examples/ScanAndSort.cpp b/velox/examples/ScanAndSort.cpp index aee9144b172d..fa9056b733cc 100644 --- a/velox/examples/ScanAndSort.cpp +++ b/velox/examples/ScanAndSort.cpp @@ -135,7 +135,8 @@ int main(int argc, char** argv) { writerPlanFragment, /*destination=*/0, core::QueryCtx::create(executor.get()), - exec::Task::ExecutionMode::kSerial); + exec::Task::ExecutionMode::kSerial, + exec::Consumer{}); // next() starts execution using the client thread. The loop pumps output // vectors out of the task (there are none in this query fragment). @@ -165,7 +166,8 @@ int main(int argc, char** argv) { readPlanFragment, /*destination=*/0, core::QueryCtx::create(executor.get()), - exec::Task::ExecutionMode::kSerial); + exec::Task::ExecutionMode::kSerial, + exec::Consumer{}); // Now that we have the query fragment and Task structure set up, we will // add data to it via `splits`. diff --git a/velox/exec/Aggregate.cpp b/velox/exec/Aggregate.cpp index dd07b7e44367..003f669e66ae 100644 --- a/velox/exec/Aggregate.cpp +++ b/velox/exec/Aggregate.cpp @@ -20,7 +20,6 @@ #include "velox/exec/AggregateCompanionAdapter.h" #include "velox/exec/AggregateCompanionSignatures.h" #include "velox/exec/AggregateWindow.h" -#include "velox/expression/SignatureBinder.h" namespace facebook::velox::exec { @@ -295,24 +294,6 @@ std::unique_ptr Aggregate::create( const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& config) { - // TODO(timaou, kletkavrubashku): Reneable the validation once "regr_slope" - // signature is fixed - // - // Validate the result type. 
if (isPartialOutput(step)) { - // auto intermediateType = Aggregate::intermediateType(name, argTypes); - // VELOX_CHECK( - // resultType->equivalent(*intermediateType), - // "Intermediate type mismatch. Expected: {}, actual: {}", - // intermediateType->toString(), - // resultType->toString()); - // } else { - // auto finalType = Aggregate::finalType(name, argTypes); - // VELOX_CHECK( - // resultType->equivalent(*finalType), - // "Final type mismatch. Expected: {}, actual: {}", - // finalType->toString(), - // resultType->toString()); - // } // Lookup the function in the new registry first. if (auto func = getAggregateFunctionEntry(name)) { return func->factory(step, argTypes, resultType, config); diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index 69e7e16bd439..a3052ea9daca 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -505,6 +505,10 @@ using AggregateFunctionFactory = std::function( const core::QueryConfig& config)>; struct AggregateFunctionMetadata { + /// True if results of the aggregation ignore duplicate values. + /// For example, min and max ignore duplicates while sum does not. + bool ignoreDuplicates{false}; + /// True if results of the aggregation depend on the order of inputs. For /// example, array_agg is order sensitive while count is not. bool orderSensitive{true}; diff --git a/velox/exec/AggregateInfo.cpp b/velox/exec/AggregateInfo.cpp index a3f28c0684dc..f39ed7e3b6f3 100644 --- a/velox/exec/AggregateInfo.cpp +++ b/velox/exec/AggregateInfo.cpp @@ -81,11 +81,10 @@ std::vector toAggregateInfo( arg->toString()); } } + const auto& name = aggregate.call->name(); - info.distinct = aggregate.distinct; - info.intermediateType = resolveAggregateFunction( - aggregate.call->name(), aggregate.rawInputTypes) - .second; + info.intermediateType = + resolveAggregateFunction(name, aggregate.rawInputTypes).second; // Setup aggregation mask: convert the Variable Reference name to the // channel (projection) index, if there is a mask. 
@@ -98,7 +97,7 @@ std::vector toAggregateInfo( auto index = numKeys + i; const auto& aggResultType = outputType->childAt(index); info.function = Aggregate::create( - aggregate.call->name(), + name, isPartialOutput(step) ? core::AggregationNode::Step::kPartial : core::AggregationNode::Step::kSingle, aggregate.rawInputTypes, @@ -114,10 +113,13 @@ std::vector toAggregateInfo( info.function->setLambdaExpressions(lambdas, expressionEvaluator); } - // Ignore sorting properties if aggregate function is not sensitive to the - // order of inputs. - auto* entry = getAggregateFunctionEntry(aggregate.call->name()); + // 1. Ignore duplicates property + // if aggregate function is not sensitive to duplicates. + // 2. Ignore sorting properties + // if aggregate function is not sensitive to the order of inputs. + auto* entry = getAggregateFunctionEntry(name); const auto& metadata = entry->metadata; + info.distinct = !metadata.ignoreDuplicates && aggregate.distinct; if (metadata.orderSensitive) { // Sorting keys and orders. 
const auto numSortingKeys = aggregate.sortingKeys.size(); diff --git a/velox/exec/CMakeLists.txt b/velox/exec/CMakeLists.txt index ba8457f888e7..5606991e9994 100644 --- a/velox/exec/CMakeLists.txt +++ b/velox/exec/CMakeLists.txt @@ -19,14 +19,14 @@ velox_add_library( AggregateCompanionSignatures.cpp AggregateFunctionRegistry.cpp AggregateInfo.cpp - AggregationMasks.cpp AggregateWindow.cpp + AggregationMasks.cpp ArrowStream.cpp AssignUniqueId.cpp BlockingReason.cpp CallbackSink.cpp - ContainerRowSerde.cpp ColumnStatsCollector.cpp + ContainerRowSerde.cpp DistinctAggregations.cpp Driver.cpp EnforceSingleRow.cpp @@ -56,36 +56,34 @@ velox_add_library( MergeSource.cpp NestedLoopJoinBuild.cpp NestedLoopJoinProbe.cpp + SpatialIndex.cpp SpatialJoinBuild.cpp SpatialJoinProbe.cpp Operator.cpp + OperatorTraceReader.cpp + OperatorTraceScan.cpp + OperatorTraceWriter.cpp OperatorUtils.cpp OrderBy.cpp OutputBuffer.cpp OutputBufferManager.cpp - OperatorTraceReader.cpp - OperatorTraceScan.cpp - OperatorTraceWriter.cpp ParallelProject.cpp - TaskStructs.cpp - TaskTraceReader.cpp - TaskTraceWriter.cpp - Trace.cpp - TraceUtil.cpp - PartitionedOutput.cpp PartitionFunction.cpp PartitionStreamingWindowBuild.cpp + PartitionedOutput.cpp PlanNodeStats.cpp PrefixSort.cpp ProbeOperatorState.cpp - RowsStreamingWindowBuild.cpp RowContainer.cpp RowNumber.cpp - ScaledScanController.cpp + RowsStreamingWindowBuild.cpp ScaleWriterLocalPartition.cpp + ScaledScanController.cpp SortBuffer.cpp - SortedAggregations.cpp SortWindowBuild.cpp + SortedAggregations.cpp + SpatialJoinBuild.cpp + SpatialJoinProbe.cpp Spill.cpp SpillFile.cpp Spiller.cpp @@ -95,8 +93,13 @@ velox_add_library( TableWriteMerge.cpp TableWriter.cpp Task.cpp + TaskStructs.cpp + TaskTraceReader.cpp + TaskTraceWriter.cpp TopN.cpp TopNRowNumber.cpp + Trace.cpp + TraceUtil.cpp Unnest.cpp Values.cpp VectorHasher.cpp @@ -108,32 +111,22 @@ velox_add_library( velox_link_libraries( velox_exec - velox_file + velox_arrow_bridge + velox_common_base + 
velox_common_compression velox_core - velox_vector velox_connector velox_expression + velox_file + velox_presto_serializer velox_time - velox_common_base velox_test_util - velox_arrow_bridge - velox_common_compression + velox_vector ) velox_add_library(velox_cursor Cursor.cpp) -velox_link_libraries( - velox_cursor - velox_core - velox_exception - velox_expression - velox_dwio_common - velox_dwio_dwrf_reader - velox_dwio_dwrf_writer - velox_type_fbhive - velox_presto_serializer - velox_functions_prestosql - velox_aggregates -) + +velox_link_libraries(velox_cursor velox_core velox_exception velox_expression) if(${VELOX_BUILD_TESTING}) add_subdirectory(fuzzer) diff --git a/velox/exec/Cursor.cpp b/velox/exec/Cursor.cpp index bdbbc0fe50c7..94af91cb2827 100644 --- a/velox/exec/Cursor.cpp +++ b/velox/exec/Cursor.cpp @@ -174,22 +174,25 @@ class TaskCursorBase : public TaskCursor { if (!params.spillDirectory.empty()) { taskSpillDirectory_ = params.spillDirectory + "/" + taskId_; - auto fileSystem = - velox::filesystems::getFileSystem(taskSpillDirectory_, nullptr); - VELOX_CHECK_NOT_NULL(fileSystem, "File System is null!"); - try { - fileSystem->mkdir(taskSpillDirectory_); - } catch (...) { - LOG(ERROR) << "Faield to create task spill directory " - << taskSpillDirectory_ << " base director " - << params.spillDirectory << " exists[" - << std::filesystem::exists(taskSpillDirectory_) << "]"; - - std::rethrow_exception(std::current_exception()); - } + taskSpillDirectoryCb_ = params.spillDirectoryCallback; + if (taskSpillDirectoryCb_ == nullptr) { + auto fileSystem = + velox::filesystems::getFileSystem(taskSpillDirectory_, nullptr); + VELOX_CHECK_NOT_NULL(fileSystem, "File System is null!"); + try { + fileSystem->mkdir(taskSpillDirectory_); + } catch (...) 
{ + LOG(ERROR) << "Faield to create task spill directory " + << taskSpillDirectory_ << " base director " + << params.spillDirectory << " exists[" + << std::filesystem::exists(taskSpillDirectory_) << "]"; + + std::rethrow_exception(std::current_exception()); + } - LOG(INFO) << "Task spill directory[" << taskSpillDirectory_ - << "] created"; + LOG(INFO) << "Task spill directory[" << taskSpillDirectory_ + << "] created"; + } } } @@ -198,6 +201,7 @@ class TaskCursorBase : public TaskCursor { std::shared_ptr queryCtx_; core::PlanFragment planFragment_; std::string taskSpillDirectory_; + std::function taskSpillDirectoryCb_; private: std::shared_ptr executor_; @@ -222,7 +226,14 @@ class MultiThreadedTaskCursor : public TaskCursorBase { std::make_shared(params.bufferedBytes, params.outputPool); // Captured as a shared_ptr by the consumer callback of task_. - auto queue = queue_; + auto queueHolder = std::weak_ptr(queue_); + std::optional spillDiskOpts; + if (!taskSpillDirectory_.empty()) { + spillDiskOpts = common::SpillDiskOptions{ + .spillDirPath = taskSpillDirectory_, + .spillDirCreated = taskSpillDirectoryCb_ == nullptr, + .spillDirCreateCb = taskSpillDirectoryCb_}; + } task_ = Task::create( taskId_, std::move(planFragment_), @@ -230,10 +241,15 @@ class MultiThreadedTaskCursor : public TaskCursorBase { std::move(queryCtx_), Task::ExecutionMode::kParallel, // consumer - [queue, copyResult = params.copyResult]( + [queueHolder, copyResult = params.copyResult, taskId = taskId_]( const RowVectorPtr& vector, bool drained, velox::ContinueFuture* future) { + auto queue = queueHolder.lock(); + if (queue == nullptr) { + LOG(ERROR) << "TaskQueue has been destroyed, taskId: " << taskId; + return exec::BlockingReason::kNotBlocked; + } VELOX_CHECK( !drained, "Unexpected drain in multithreaded task cursor"); if (!vector || !copyResult) { @@ -249,16 +265,18 @@ class MultiThreadedTaskCursor : public TaskCursorBase { return queue->enqueue(std::move(copy), future); }, 0, - 
[queue](std::exception_ptr) { + std::move(spillDiskOpts), + [queueHolder, taskId = taskId_](std::exception_ptr) { // onError close the queue to unblock producers and consumers. // moveNext will handle rethrowing the error once it's // unblocked. + auto queue = queueHolder.lock(); + if (queue == nullptr) { + LOG(ERROR) << "TaskQueue has been destroyed, taskId: " << taskId; + return; + } queue->close(); }); - - if (!taskSpillDirectory_.empty()) { - task_->setSpillDirectory(taskSpillDirectory_); - } } ~MultiThreadedTaskCursor() override { @@ -371,17 +389,22 @@ class SingleThreadedTaskCursor : public TaskCursorBase { VELOX_CHECK( !queryCtx_->isExecutorSupplied(), "Executor should not be set in serial task cursor"); - + std::optional spillDiskOpts; + if (!taskSpillDirectory_.empty()) { + spillDiskOpts = common::SpillDiskOptions{ + .spillDirPath = taskSpillDirectory_, + .spillDirCreated = true, + .spillDirCreateCb = taskSpillDirectoryCb_}; + } task_ = Task::create( taskId_, std::move(planFragment_), params.destination, std::move(queryCtx_), - Task::ExecutionMode::kSerial); - - if (!taskSpillDirectory_.empty()) { - task_->setSpillDirectory(taskSpillDirectory_); - } + Task::ExecutionMode::kSerial, + std::function{}, + 0, + std::move(spillDiskOpts)); VELOX_CHECK( task_->supportSerialExecutionMode(), diff --git a/velox/exec/Cursor.h b/velox/exec/Cursor.h index 57d1ec8380aa..6cce86d0a9cd 100644 --- a/velox/exec/Cursor.h +++ b/velox/exec/Cursor.h @@ -71,6 +71,12 @@ struct CursorParameters { /// would be built from it. std::string spillDirectory; + /// Callback function to dynamically create or determine the spill directory + /// path at runtime. If provided, this callback is invoked when spilling is + /// needed and must return a valid directory path. This allows for dynamic + /// spill directory creation or path resolution based on runtime conditions. + std::function spillDirectoryCallback; + bool copyResult = true; /// If true, use serial execution mode. 
Use parallel execution mode diff --git a/velox/exec/Driver.cpp b/velox/exec/Driver.cpp index 202553ef41ba..b356542583b8 100644 --- a/velox/exec/Driver.cpp +++ b/velox/exec/Driver.cpp @@ -145,7 +145,8 @@ std::optional DriverCtx::makeSpillConfig( queryConfig.spillPrefixSortEnabled() ? std::optional(prefixSortConfig()) : std::nullopt, - queryConfig.spillFileCreateConfig()); + queryConfig.spillFileCreateConfig(), + queryConfig.windowSpillMinReadBatchRows()); } std::atomic_uint64_t BlockingState::numBlockedDrivers_{0}; @@ -443,8 +444,10 @@ inline void addInput(Operator* op, const RowVectorPtr& input) { } inline void getOutput(Operator* op, RowVectorPtr& result) { - result = op->getOutput(); - if (FOLLY_UNLIKELY(op->shouldDropOutput())) { + auto output = op->getOutput(); + if (FOLLY_LIKELY(!op->shouldDropOutput())) { + result = std::move(output); // Use move semantics to avoid ref counting + } else { result = nullptr; } } @@ -1102,9 +1105,13 @@ std::string Driver::toString() const { } out << "{Operators: "; - for (auto& op : operators_) { - out << op->toString() << ", "; - } + std::vector opStrs; + opStrs.reserve(operators_.size()); + std::ranges::transform( + operators_, std::back_inserter(opStrs), [](const auto& op) { + return op->toString(); + }); + out << folly::join(", ", opStrs); out << "}"; const auto ocs = opCallStatus(); if (!ocs.empty()) { diff --git a/velox/exec/Exchange.cpp b/velox/exec/Exchange.cpp index 4afdb425e7a6..419fb9dd00ce 100644 --- a/velox/exec/Exchange.cpp +++ b/velox/exec/Exchange.cpp @@ -15,7 +15,9 @@ */ #include "velox/exec/Exchange.h" +#include "velox/common/Casts.h" #include "velox/common/serialization/Serializable.h" +#include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/serializers/CompactRowSerializer.h" @@ -42,16 +44,18 @@ void RemoteConnectorSplit::registerSerDe() { } namespace { -std::unique_ptr getVectorSerdeOptions( - const core::QueryConfig& queryConfig, - VectorSerde::Kind kind) { - std::unique_ptr 
options = - kind == VectorSerde::Kind::kPresto - ? std::make_unique() - : std::make_unique(); - options->compressionKind = - common::stringToCompressionKind(queryConfig.shuffleCompressionKind()); - return options; +std::unique_ptr mergePages( + std::vector>& pages) { + VELOX_CHECK(!pages.empty()); + std::unique_ptr mergedBufs; + for (const auto& page : pages) { + if (mergedBufs == nullptr) { + mergedBufs = page->getIOBuf(); + } else { + mergedBufs->appendToChain(page->getIOBuf()); + } + } + return mergedBufs; } } // namespace @@ -71,7 +75,9 @@ Exchange::Exchange( driverCtx->queryConfig().preferredOutputBatchBytes()}, serdeKind_{exchangeNode->serdeKind()}, serdeOptions_{getVectorSerdeOptions( - operatorCtx_->driverCtx()->queryConfig(), + common::stringToCompressionKind(operatorCtx_->driverCtx() + ->queryConfig() + .shuffleCompressionKind()), serdeKind_)}, processSplits_{operatorCtx_->driverCtx()->driverId == 0}, driverId_{driverCtx->driverId}, @@ -85,42 +91,42 @@ void Exchange::addRemoteTaskIds(std::vector& remoteTaskIds) { stats_.wlock()->numSplits += remoteTaskIds.size(); } -bool Exchange::getSplits(ContinueFuture* future) { +void Exchange::getSplits(ContinueFuture* future) { if (!processSplits_) { - return false; + return; } if (noMoreSplits_) { - return false; + return; } std::vector remoteTaskIds; for (;;) { exec::Split split; - auto reason = operatorCtx_->task()->getSplitOrFuture( + const auto reason = operatorCtx_->task()->getSplitOrFuture( operatorCtx_->driverCtx()->splitGroupId, planNodeId(), split, *future); - if (reason == BlockingReason::kNotBlocked) { - if (split.hasConnectorSplit()) { - auto remoteSplit = std::dynamic_pointer_cast( - split.connectorSplit); - VELOX_CHECK_NOT_NULL(remoteSplit, "Wrong type of split"); - if (FOLLY_UNLIKELY(splitTracer_ != nullptr)) { - splitTracer_->write(split); - } - remoteTaskIds.push_back(remoteSplit->taskId); - } else { - addRemoteTaskIds(remoteTaskIds); - exchangeClient_->noMoreRemoteTasks(); - noMoreSplits_ = true; 
- if (atEnd_) { - operatorCtx_->task()->multipleSplitsFinished( - false, stats_.rlock()->numSplits, 0); - recordExchangeClientStats(); - } - return false; - } - } else { + if (reason != BlockingReason::kNotBlocked) { addRemoteTaskIds(remoteTaskIds); - return true; + return; } + + if (split.hasConnectorSplit()) { + auto remoteSplit = + checked_pointer_cast(split.connectorSplit); + if (FOLLY_UNLIKELY(splitTracer_ != nullptr)) { + splitTracer_->write(split); + } + remoteTaskIds.push_back(remoteSplit->taskId); + continue; + } + + addRemoteTaskIds(remoteTaskIds); + exchangeClient_->noMoreRemoteTasks(); + noMoreSplits_ = true; + if (atEnd_) { + operatorCtx_->task()->multipleSplitsFinished( + false, stats_.rlock()->numSplits, 0); + recordExchangeClientStats(); + } + return; } } @@ -130,7 +136,6 @@ BlockingReason Exchange::isBlocked(ContinueFuture* future) { } // Start fetching data right away. Do not wait for all splits to be available. - if (!splitFuture_.valid()) { getSplits(&splitFuture_); } @@ -159,6 +164,7 @@ BlockingReason Exchange::isBlocked(ContinueFuture* future) { } // Block until data becomes available. 
+ VELOX_CHECK(dataFuture.valid()); *future = std::move(dataFuture); return BlockingReason::kWaitForProducer; } @@ -167,22 +173,6 @@ bool Exchange::isFinished() { return atEnd_ && currentPages_.empty(); } -namespace { -std::unique_ptr mergePages( - std::vector>& pages) { - VELOX_CHECK(!pages.empty()); - std::unique_ptr mergedBufs; - for (const auto& page : pages) { - if (mergedBufs == nullptr) { - mergedBufs = page->getIOBuf(); - } else { - mergedBufs->appendToChain(page->getIOBuf()); - } - } - return mergedBufs; -} -} // namespace - RowVectorPtr Exchange::getOutput() { auto* serde = getSerde(); if (serde->supportsAppendInDeserialize()) { @@ -210,59 +200,54 @@ RowVectorPtr Exchange::getOutput() { recordInputStats(rawInputBytes); return result_; } - if (serde->kind() == VectorSerde::Kind::kCompactRow) { - return getOutputFromCompactRows(serde); - } - if (serde->kind() == VectorSerde::Kind::kUnsafeRow) { - return getOutputFromUnsafeRows(serde); - } - VELOX_UNREACHABLE( - "Unsupported serde kind: {}", VectorSerde::kindName(serde->kind())); + return getOutputFromRows(serde); } -RowVectorPtr Exchange::getOutputFromCompactRows(VectorSerde* serde) { +RowVectorPtr Exchange::getOutputFromRows(VectorSerde* serde) { uint64_t rawInputBytes{0}; if (currentPages_.empty()) { - VELOX_CHECK_NULL(compactRowInputStream_); - VELOX_CHECK_NULL(compactRowIterator_); + VELOX_CHECK_NULL(rowInputStream_); + VELOX_CHECK_NULL(rowIterator_); return nullptr; } - if (compactRowInputStream_ == nullptr) { + if (rowInputStream_ == nullptr) { std::unique_ptr mergedBufs = mergePages(currentPages_); rawInputBytes += mergedBufs->computeChainDataLength(); - compactRowPages_ = std::make_unique(std::move(mergedBufs)); - compactRowInputStream_ = compactRowPages_->prepareStreamForDeserialize(); + rowPages_ = std::make_unique(std::move(mergedBufs)); + rowInputStream_ = rowPages_->prepareStreamForDeserialize(); } - auto numRows = kInitialOutputCompactRows; - if (estimatedCompactRowSize_.has_value()) { + auto 
numRows = kInitialOutputRows; + if (estimatedRowSize_.has_value()) { numRows = std::max( - (preferredOutputBatchBytes_ / estimatedCompactRowSize_.value()), - kInitialOutputCompactRows); + (preferredOutputBatchBytes_ / estimatedRowSize_.value()), + kInitialOutputRows); } + // Check if the serde supports batched deserialization serde->deserialize( - compactRowInputStream_.get(), - compactRowIterator_, + rowInputStream_.get(), + rowIterator_, numRows, outputType_, &result_, pool(), serdeOptions_.get()); + const auto numOutputRows = result_->size(); VELOX_CHECK_GT(numOutputRows, 0); - estimatedCompactRowSize_ = std::max( + estimatedRowSize_ = std::max( result_->estimateFlatSize() / numOutputRows, - estimatedCompactRowSize_.value_or(1L)); + estimatedRowSize_.value_or(1L)); - if (compactRowInputStream_->atEnd() && compactRowIterator_ == nullptr) { + if (rowInputStream_->atEnd() && rowIterator_ == nullptr) { // only clear the input stream if we have reached the end of the row // iterator because row iterator may depend on input stream if serialized // rows are not compressed. 
- compactRowInputStream_ = nullptr; - compactRowPages_ = nullptr; + rowInputStream_ = nullptr; + rowPages_ = nullptr; currentPages_.clear(); } @@ -270,22 +255,6 @@ RowVectorPtr Exchange::getOutputFromCompactRows(VectorSerde* serde) { return result_; } -RowVectorPtr Exchange::getOutputFromUnsafeRows(VectorSerde* serde) { - uint64_t rawInputBytes{0}; - if (currentPages_.empty()) { - return nullptr; - } - std::unique_ptr mergedBufs = mergePages(currentPages_); - rawInputBytes += mergedBufs->computeChainDataLength(); - auto mergedPages = std::make_unique(std::move(mergedBufs)); - auto source = mergedPages->prepareStreamForDeserialize(); - serde->deserialize( - source.get(), pool(), outputType_, &result_, serdeOptions_.get()); - currentPages_.clear(); - recordInputStats(rawInputBytes); - return result_; -} - void Exchange::recordInputStats(uint64_t rawInputBytes) { auto lockedStats = stats_.wlock(); lockedStats->rawInputBytes += rawInputBytes; @@ -325,7 +294,7 @@ void Exchange::recordExchangeClientStats() { lockedStats->runtimeStats.insert({name, value}); } - auto backgroundCpuTimeMs = + const auto backgroundCpuTimeMs = exchangeClientStats.find(ExchangeClient::kBackgroundCpuTimeMs); if (backgroundCpuTimeMs != exchangeClientStats.end()) { const CpuWallTiming backgroundTiming{ diff --git a/velox/exec/Exchange.h b/velox/exec/Exchange.h index 2e235b698e99..6c2b87da9b4a 100644 --- a/velox/exec/Exchange.h +++ b/velox/exec/Exchange.h @@ -67,12 +67,11 @@ class Exchange : public SourceOperator { protected: virtual VectorSerde* getSerde(); - private: - // When 'estimatedCompactRowSize_' is unset, meaning we haven't materialized + // When 'estimatedRowSize_' is unset, meaning we haven't materialized // and returned any output from this exchange operator, we return this // conservative number of output rows, to make sure memory does not grow too // much. 
- static constexpr uint64_t kInitialOutputCompactRows = 64; + static constexpr uint64_t kInitialOutputRows = 64; // Invoked to create exchange client for remote tasks. The function shuffles // the source task ids first to randomize the source tasks we fetch data from. @@ -82,11 +81,8 @@ class Exchange : public SourceOperator { // Fetches splits from the task until there are no more splits or task returns // a future that will be complete when more splits arrive. Adds splits to - // exchangeClient_. Returns true if received a future from the task and sets - // the 'future' parameter. Returns false if fetched all splits or if this - // operator is not the first operator in the pipeline and therefore is not - // responsible for fetching splits and adding them to the exchangeClient_. - bool getSplits(ContinueFuture* future); + // exchangeClient_. + void getSplits(ContinueFuture* future); // Fetches runtime stats from ExchangeClient and replaces these in this // operator's stats. @@ -94,9 +90,7 @@ class Exchange : public SourceOperator { void recordInputStats(uint64_t rawInputBytes); - RowVectorPtr getOutputFromCompactRows(VectorSerde* serde); - - RowVectorPtr getOutputFromUnsafeRows(VectorSerde* serde); + RowVectorPtr getOutputFromRows(VectorSerde* serde); const uint64_t preferredOutputBatchBytes_; @@ -125,15 +119,15 @@ class Exchange : public SourceOperator { bool atEnd_{false}; std::default_random_engine rng_{std::random_device{}()}; - // Memory holders needed by compact row serde to perform cursor like reads - // across 'getOutputFromCompactRows' calls. - std::unique_ptr compactRowPages_; - std::unique_ptr compactRowInputStream_; - std::unique_ptr compactRowIterator_; + // Memory holders needed by row serde to perform cursor like reads + // across 'getOutputFromRows' calls. 
+ std::unique_ptr rowPages_; + std::unique_ptr rowInputStream_; + std::unique_ptr rowIterator_; // The estimated bytes per row of the output of this exchange operator // computed from the last processed output. - std::optional estimatedCompactRowSize_; + std::optional estimatedRowSize_; }; } // namespace facebook::velox::exec diff --git a/velox/exec/ExchangeQueue.cpp b/velox/exec/ExchangeQueue.cpp index 7d1c24369b94..6f04f67ae9cb 100644 --- a/velox/exec/ExchangeQueue.cpp +++ b/velox/exec/ExchangeQueue.cpp @@ -128,7 +128,7 @@ void ExchangeQueue::addPromiseLocked( *stalePromise = std::move(it->second); it->second = std::move(promise); } else { - promises_[consumerId] = std::move(promise); + promises_.emplace(consumerId, std::move(promise)); } VELOX_CHECK_LE(promises_.size(), numberOfConsumers_); } diff --git a/velox/exec/ExchangeQueue.h b/velox/exec/ExchangeQueue.h index 4f77360fdbc7..ba9e12c9402d 100644 --- a/velox/exec/ExchangeQueue.h +++ b/velox/exec/ExchangeQueue.h @@ -29,7 +29,7 @@ class SerializedPage { std::function onDestructionCb = nullptr, std::optional numRows = std::nullopt); - ~SerializedPage(); + virtual ~SerializedPage(); /// Returns the size of the serialized data in bytes. uint64_t size() const { diff --git a/velox/exec/FilterProject.h b/velox/exec/FilterProject.h index 79aafcc4a8d6..b5ae557e0fd6 100644 --- a/velox/exec/FilterProject.h +++ b/velox/exec/FilterProject.h @@ -82,6 +82,12 @@ class FilterProject : public Operator { /// tracking is enabled via query config. OperatorStats stats(bool clear) override; + /// Returns the filterNode, call this function before initialize the operator, + /// this field is reset in function initialize. + const std::shared_ptr& filterNode() const { + return filter_; + } + private: // Evaluate filter on all rows. Return number of rows that passed the filter. 
// Populate filterEvalCtx_.selectedBits and selectedIndices with the indices diff --git a/velox/exec/GroupingSet.h b/velox/exec/GroupingSet.h index 43ce19eed1ce..d71f907f1e37 100644 --- a/velox/exec/GroupingSet.h +++ b/velox/exec/GroupingSet.h @@ -15,13 +15,13 @@ */ #pragma once +#include "velox/common/base/TreeOfLosers.h" #include "velox/exec/AggregateInfo.h" #include "velox/exec/AggregationMasks.h" #include "velox/exec/DistinctAggregations.h" #include "velox/exec/HashTable.h" #include "velox/exec/SortedAggregations.h" #include "velox/exec/Spiller.h" -#include "velox/exec/TreeOfLosers.h" #include "velox/exec/VectorHasher.h" namespace facebook::velox::exec { diff --git a/velox/exec/HashJoinBridge.cpp b/velox/exec/HashJoinBridge.cpp index c9dc19948a00..3a4eaf495418 100644 --- a/velox/exec/HashJoinBridge.cpp +++ b/velox/exec/HashJoinBridge.cpp @@ -452,11 +452,11 @@ uint64_t HashJoinMemoryReclaimer::reclaim( } bool isHashBuildMemoryPool(const memory::MemoryPool& pool) { - return folly::StringPiece(pool.name()).endsWith("HashBuild"); + return pool.name().ends_with("HashBuild"); } bool isHashProbeMemoryPool(const memory::MemoryPool& pool) { - return folly::StringPiece(pool.name()).endsWith("HashProbe"); + return pool.name().ends_with("HashProbe"); } bool needRightSideJoin(core::JoinType joinType) { diff --git a/velox/exec/IndexLookupJoin.cpp b/velox/exec/IndexLookupJoin.cpp index f0e437ec7139..692c481389d9 100644 --- a/velox/exec/IndexLookupJoin.cpp +++ b/velox/exec/IndexLookupJoin.cpp @@ -146,12 +146,12 @@ IndexLookupJoin::IndexLookupJoin( ? 
outputBatchRows() : std::numeric_limits::max()}, joinType_{joinNode->joinType()}, - includeMatchColumn_(joinNode->includeMatchColumn()), + hasMarker_(joinNode->hasMarker()), numKeys_{joinNode->leftKeys().size()}, probeType_{joinNode->sources()[0]->outputType()}, lookupType_{joinNode->lookupSource()->outputType()}, lookupTableHandle_{joinNode->lookupSource()->tableHandle()}, - lookupConditions_{joinNode->joinConditions()}, + joinConditions_{joinNode->joinConditions()}, lookupColumnHandles_(joinNode->lookupSource()->assignments()), connectorQueryCtx_{operatorCtx_->createConnectorQueryCtx( lookupTableHandle_->connectorId(), @@ -184,11 +184,12 @@ void IndexLookupJoin::initialize() { initLookupInput(); initLookupOutput(); initOutputProjections(); + initFilter(); indexSource_ = connector_->createIndexSource( lookupInputType_, numKeys_, - lookupConditions_, + joinConditions_, lookupOutputType_, lookupTableHandle_, lookupColumnHandles_, @@ -215,14 +216,13 @@ void IndexLookupJoin::initLookupInput() { VELOX_CHECK(lookupInputChannels_.empty()); std::vector lookupInputNames; - lookupInputNames.reserve(numKeys_ + lookupConditions_.size()); + lookupInputNames.reserve(numKeys_ + joinConditions_.size()); std::vector lookupInputTypes; - lookupInputTypes.reserve(numKeys_ + lookupConditions_.size()); - lookupInputChannels_.reserve(numKeys_ + lookupConditions_.size()); + lookupInputTypes.reserve(numKeys_ + joinConditions_.size()); + lookupInputChannels_.reserve(numKeys_ + joinConditions_.size()); SCOPE_EXIT { - VELOX_CHECK_GE( - lookupInputNames.size(), numKeys_ + lookupConditions_.size()); + VELOX_CHECK_GE(lookupInputNames.size(), numKeys_ + joinConditions_.size()); VELOX_CHECK_EQ(lookupInputNames.size(), lookupInputChannels_.size()); lookupInputType_ = ROW(std::move(lookupInputNames), std::move(lookupInputTypes)); @@ -257,19 +257,19 @@ void IndexLookupJoin::initLookupInput() { lookupKeyOrConditionHashers_ = createVectorHashers(probeType_, lookupInputChannels_); }; - if 
(lookupConditions_.empty()) { + if (joinConditions_.empty()) { return; } - for (const auto& lookupCondition : lookupConditions_) { - const auto indexKeyName = getColumnName(lookupCondition->key); + for (const auto& joinCondition : joinConditions_) { + const auto indexKeyName = getColumnName(joinCondition->key); VELOX_USER_CHECK_EQ(lookupIndexColumnSet.count(indexKeyName), 0); lookupIndexColumnSet.insert(indexKeyName); const auto indexKeyType = lookupType_->findChild(indexKeyName); if (const auto inCondition = std::dynamic_pointer_cast( - lookupCondition)) { + joinCondition)) { const auto conditionInputName = getColumnName(inCondition->list); const auto conditionInputChannel = probeType_->getChildIdx(conditionInputName); @@ -290,7 +290,7 @@ void IndexLookupJoin::initLookupInput() { if (const auto betweenCondition = std::dynamic_pointer_cast( - lookupCondition)) { + joinCondition)) { addBetweenCondition( betweenCondition, probeType_, @@ -303,7 +303,7 @@ void IndexLookupJoin::initLookupInput() { if (const auto equalCondition = std::dynamic_pointer_cast( - lookupCondition)) { + joinCondition)) { // Process an equal join condition by validating that the value is // constant. Equal conditions only support constant values for filtering. 
VELOX_USER_CHECK( @@ -371,15 +371,63 @@ void IndexLookupJoin::initOutputProjections() { } lookupOutputProjections_.emplace_back(i, outputChannelOpt.value()); } - if (includeMatchColumn_) { + if (hasMarker_) { matchOutputChannel_ = outputType_->size() - 1; } + VELOX_USER_CHECK_EQ( probeOutputProjections_.size() + lookupOutputProjections_.size() + !!matchOutputChannel_.has_value(), outputType_->size()); } +void IndexLookupJoin::initFilter() { + VELOX_CHECK_NULL(filter_); + + if (joinNode_->filter() == nullptr) { + return; + } + + std::vector filters = {joinNode_->filter()}; + filter_ = + std::make_unique(std::move(filters), operatorCtx_->execCtx()); + + std::vector names; + std::vector types; + const auto numFields = filter_->expr(0)->distinctFields().size(); + names.reserve(numFields); + types.reserve(numFields); + + column_index_t filterChannel{0}; + const auto addChannel = [&](column_index_t channel, + const RowTypePtr& inputType, + std::vector& projections) { + names.emplace_back(inputType->nameOf(channel)); + types.emplace_back(inputType->childAt(channel)); + projections.emplace_back(channel, filterChannel++); + }; + + for (const auto& field : filter_->expr(0)->distinctFields()) { + const auto& name = field->field(); + auto channel = probeType_->getChildIdxIfExists(name); + if (channel.has_value()) { + addChannel(channel.value(), probeType_, filterProbeInputProjections_); + continue; + } + channel = lookupOutputType_->getChildIdxIfExists(name); + if (channel.has_value()) { + addChannel( + channel.value(), lookupOutputType_, filterLookupInputProjections_); + continue; + } + VELOX_FAIL( + "Index lookup join filter field not found in either left or right input: {}", + field->toString()); + } + + filterInputType_ = ROW(std::move(names), std::move(types)); +} + bool IndexLookupJoin::startDrain() { return numInputBatches() != 0; } @@ -579,11 +627,7 @@ RowVectorPtr IndexLookupJoin::getOutputFromLookupResult( batch.lookupFuture = ContinueFuture::makeEmpty(); if 
(batch.lookupInput->size() == 0) { - if (hasRemainingOutputForLeftJoin(batch)) { - return produceRemainingOutputForLeftJoin(batch); - } - finishInput(batch); - return nullptr; + return produceRemainingOutput(batch); } VELOX_CHECK_NOT_NULL(batch.lookupResultIter); @@ -609,6 +653,13 @@ RowVectorPtr IndexLookupJoin::getOutputFromLookupResult( prepareLookupResult(batch); VELOX_CHECK_NOT_NULL(batch.lookupResult); + if (!applyFilterOnLookupResult(batch)) { + VELOX_CHECK_NULL(batch.lookupResult); + // All rows in lookup result are filtered out, and fetch next lookup result + // batch. + return nullptr; + } + SCOPE_EXIT { maybeFinishLookupResult(batch); }; @@ -618,6 +669,14 @@ RowVectorPtr IndexLookupJoin::getOutputFromLookupResult( return produceOutputForLeftJoin(batch); } +RowVectorPtr IndexLookupJoin::produceRemainingOutput(InputBatchState& batch) { + if (hasRemainingOutputForLeftJoin(batch)) { + return produceRemainingOutputForLeftJoin(batch); + } + finishInput(batch); + return nullptr; +} + void IndexLookupJoin::prepareLookupResult(InputBatchState& batch) { VELOX_CHECK_NOT_NULL(batch.lookupResult); if (rawLookupInputHitIndices_ != nullptr) { @@ -630,28 +689,8 @@ void IndexLookupJoin::prepareLookupResult(InputBatchState& batch) { return; } VELOX_CHECK_NOT_NULL(batch.nonNullInputMappings); - vector_size_t* rawLookupInputHitIndices{nullptr}; - if (batch.lookupResult->inputHits->isMutable()) { - rawLookupInputHitIndices = - batch.lookupResult->inputHits->asMutable(); - } else { - const auto indicesByteSize = - batch.lookupResult->size() * sizeof(vector_size_t); - if ((batch.resultInputHitIndices == nullptr) || - !batch.resultInputHitIndices->unique() || - (batch.resultInputHitIndices->capacity() < indicesByteSize)) { - batch.resultInputHitIndices = allocateIndices(indicesByteSize, pool()); - } else { - batch.resultInputHitIndices->setSize(indicesByteSize); - } - rawLookupInputHitIndices = - batch.resultInputHitIndices->asMutable(); - std::memcpy( - 
rawLookupInputHitIndices, - batch.lookupResult->inputHits->as(), - indicesByteSize); - batch.lookupResult->inputHits = batch.resultInputHitIndices; - } + vector_size_t* rawLookupInputHitIndices = + batch.ensureInputHitsWritable(pool()); for (auto i = 0; i < batch.lookupResult->size(); ++i) { rawLookupInputHitIndices[i] = batch.rawNonNullInputMappings[rawLookupInputHitIndices[i]]; @@ -665,15 +704,44 @@ void IndexLookupJoin::prepareLookupResult(InputBatchState& batch) { rawLookupInputHitIndices_ = rawLookupInputHitIndices; } +vector_size_t* IndexLookupJoin::InputBatchState::ensureInputHitsWritable( + memory::MemoryPool* pool) { + VELOX_CHECK_NOT_NULL(lookupResult); + if (lookupResult->inputHits->isMutable()) { + return lookupResult->inputHits->asMutable(); + } + + const auto indicesByteSize = lookupResult->size() * sizeof(vector_size_t); + if ((resultInputHitIndices == nullptr) || + !resultInputHitIndices->isMutable() || + (resultInputHitIndices->capacity() < indicesByteSize)) { + resultInputHitIndices = allocateIndices(indicesByteSize, pool); + } else { + resultInputHitIndices->setSize(indicesByteSize); + } + auto* rawLookupInputHitIndices = + resultInputHitIndices->asMutable(); + std::memcpy( + rawLookupInputHitIndices, + lookupResult->inputHits->as(), + indicesByteSize); + lookupResult->inputHits = resultInputHitIndices; + return rawLookupInputHitIndices; +} + void IndexLookupJoin::maybeFinishLookupResult(InputBatchState& batch) { VELOX_CHECK_NOT_NULL(batch.lookupResult); if (nextOutputResultRow_ == batch.lookupResult->size()) { - batch.lookupResult = nullptr; - nextOutputResultRow_ = 0; - rawLookupInputHitIndices_ = nullptr; + finishLookupResult(batch); } } +void IndexLookupJoin::finishLookupResult(InputBatchState& batch) { + batch.lookupResult = nullptr; + nextOutputResultRow_ = 0; + rawLookupInputHitIndices_ = nullptr; +} + bool IndexLookupJoin::hasRemainingOutputForLeftJoin( const InputBatchState& batch) const { if (joinType_ != core::JoinType::kLeft) { @@ 
-775,7 +843,7 @@ void IndexLookupJoin::fillOutputMatchRows( offset, offset + size, match ? bits::kNotNull : bits::kNull); - if (!includeMatchColumn_) { + if (!hasMarker_) { return; } VELOX_CHECK_NOT_NULL(rawMatchValues_); @@ -806,7 +874,11 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( for (; numOutputRows < maxOutputRows && nextOutputResultRow_ < batch.lookupResult->size();) { VELOX_CHECK_GE( - rawLookupInputHitIndices_[nextOutputResultRow_], lastProcessedInputRow); + rawLookupInputHitIndices_[nextOutputResultRow_], + lastProcessedInputRow, + "nextOutputResultRow_ {}, batch.lookupResult->size() {}", + nextOutputResultRow_, + batch.lookupResult->size()); const vector_size_t numMissedInputRows = rawLookupInputHitIndices_[nextOutputResultRow_] - lastProcessedInputRow - 1; @@ -883,7 +955,7 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( numOutputRows, batch.lookupResult->output->childAt(projection.inputChannel)); } - if (includeMatchColumn_) { + if (hasMarker_) { output_->childAt(matchOutputChannel_.value()) = matchColumn_; } } else { @@ -900,7 +972,7 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( ->slice(startOutputRow, numOutputRows); } } - if (includeMatchColumn_) { + if (hasMarker_) { output_->childAt(matchOutputChannel_.value()) = BaseVector::createConstant(BOOLEAN(), true, numOutputRows, pool()); } @@ -909,7 +981,7 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( } void IndexLookupJoin::ensureMatchColumn(vector_size_t maxOutputRows) { - if (!includeMatchColumn_) { + if (!hasMarker_) { return; } if (matchColumn_) { @@ -925,7 +997,7 @@ void IndexLookupJoin::ensureMatchColumn(vector_size_t maxOutputRows) { } void IndexLookupJoin::setMatchColumnSize(vector_size_t numOutputRows) { - if (!includeMatchColumn_) { + if (!hasMarker_) { return; } VELOX_CHECK_NOT_NULL(matchColumn_); @@ -964,7 +1036,7 @@ RowVectorPtr IndexLookupJoin::produceRemainingOutputForLeftJoin( numOutputRows, pool()); } - if (includeMatchColumn_) { + if 
(hasMarker_) { output_->childAt(matchOutputChannel_.value()) = BaseVector::createConstant(BOOLEAN(), false, numOutputRows, pool()); } @@ -1017,6 +1089,111 @@ void IndexLookupJoin::close() { Operator::close(); } +bool IndexLookupJoin::applyFilterOnLookupResult(InputBatchState& batch) { + VELOX_CHECK_NOT_NULL(batch.lookupResult); + if (!filter_) { + return true; + } + if (batch.lookupResult->size() == 0) { + return true; + } + + const auto numResultRows = batch.lookupResult->size(); + + // Prepare filter input vector + filterRows_.resize(numResultRows); + filterRows_.setAll(); + + if (!filterInput_) { + filterInput_ = + BaseVector::create(filterInputType_, numResultRows, pool()); + } else { + VectorPtr filterInputVector = std::move(filterInput_); + BaseVector::prepareForReuse(filterInputVector, numResultRows); + filterInput_ = std::static_pointer_cast(filterInputVector); + } + + // Populate filter input from probe input. + for (const auto& projection : filterProbeInputProjections_) { + // Get the probe input column and dictionary-wrap it with hit indices + filterInput_->childAt(projection.outputChannel) = + BaseVector::wrapInDictionary( + nullptr, + batch.lookupResult->inputHits, + numResultRows, + batch.input->childAt(projection.inputChannel)); + } + + // Populate filter input from lookup result. 
+ for (const auto& projection : filterLookupInputProjections_) { + filterInput_->childAt(projection.outputChannel) = + batch.lookupResult->output->childAt(projection.inputChannel); + } + + // Evaluate filter + filterResult_.resize(1); + EvalCtx evalCtx(operatorCtx_->execCtx(), filter_.get(), filterInput_.get()); + filter_->eval(filterRows_, evalCtx, filterResult_); + decodedFilterResult_.decode(*filterResult_[0], filterRows_); + + const auto indicesByteSize = numResultRows * sizeof(vector_size_t); + if (!filteredIndices_ || !filteredIndices_->isMutable() || + filteredIndices_->capacity() < indicesByteSize) { + filteredIndices_ = allocateIndices(numResultRows, pool()); + } else { + filteredIndices_->setSize(indicesByteSize); + } + auto* rawFilteredIndices = filteredIndices_->asMutable(); + + vector_size_t numPassed{0}; + for (auto i = 0; i < numResultRows; ++i) { + if (!decodedFilterResult_.isNullAt(i) && + decodedFilterResult_.valueAt(i)) { + rawFilteredIndices[numPassed++] = i; + } + } + + if (numPassed == 0) { + finishLookupResult(batch); + return false; + } + + if (numPassed == numResultRows) { + return true; + } + + // Some rows passed - create filtered lookup result. + filteredIndices_->setSize(numPassed * sizeof(vector_size_t)); + + // Update the inputHits buffer. + auto* rawLookupInputHitIndices = batch.ensureInputHitsWritable(pool()); + for (auto i = 0; i < numPassed; ++i) { + rawLookupInputHitIndices[i] = + rawLookupInputHitIndices_[rawFilteredIndices[i]]; +#ifdef NDEBUG + if (i > 0) { + VELOX_DCHECK_LE( + rawLookupInputHitIndices[i - 1], rawLookupInputHitIndices[i]); + } +#endif + } + batch.lookupResult->inputHits->setSize(numPassed * sizeof(vector_size_t)); + rawLookupInputHitIndices_ = rawLookupInputHitIndices; + + // Create the filtered result vector. 
+ auto filteredOutput = BaseVector::create( + batch.lookupResult->output->type(), numPassed, pool()); + for (auto i = 0; i < batch.lookupResult->output->childrenSize(); ++i) { + filteredOutput->childAt(i) = BaseVector::wrapInDictionary( + nullptr, + filteredIndices_, + numPassed, + batch.lookupResult->output->childAt(i)); + } + batch.lookupResult->output = std::move(filteredOutput); + return true; +} + void IndexLookupJoin::recordConnectorStats() { if (indexSource_ == nullptr) { // NOTE: index join might fail to create index source so skip record stats diff --git a/velox/exec/IndexLookupJoin.h b/velox/exec/IndexLookupJoin.h index 159324b7bb79..2d2f43e60380 100644 --- a/velox/exec/IndexLookupJoin.h +++ b/velox/exec/IndexLookupJoin.h @@ -131,6 +131,13 @@ class IndexLookupJoin : public Operator { bool empty() const { return input == nullptr; } + + // Ensures that the lookup result's inputHits buffer is writable and returns + // a mutable pointer. If the buffer is already mutable, returns it directly. + // Otherwise, creates a new writable buffer by copying the existing data and + // returns a pointer to the new buffer. This is needed when filters or null + // key handling requires modifying the input hit indices. + vector_size_t* ensureInputHitsWritable(memory::MemoryPool* pool); }; void initInputBatches(); @@ -138,6 +145,13 @@ class IndexLookupJoin : public Operator { void initLookupInput(); void initLookupOutput(); void initOutputProjections(); + void initFilter(); + + // Applies the join filter directly on the lookup result, updating the + // lookup result to only include rows that pass the filter. Returns true if + // some rows passed the filter, otherwise false. + bool applyFilterOnLookupResult(InputBatchState& batch); + void ensureInputLoaded(const InputBatchState& batch); // Prepare index source lookup for a given 'input_'. 
void prepareLookup(InputBatchState& batch); @@ -149,6 +163,11 @@ class IndexLookupJoin : public Operator { RowVectorPtr getOutputFromLookupResult(InputBatchState& batch); RowVectorPtr produceOutputForInnerJoin(const InputBatchState& batch); RowVectorPtr produceOutputForLeftJoin(const InputBatchState& batch); + // Handles production of remaining output after lookup result processing is + // complete. For left joins, this ensures unmatched rows from the probe side + // are included in the output with null values for lookup columns. For inner + // joins, this simply finishes the input batch. + RowVectorPtr produceRemainingOutput(InputBatchState& batch); // Produces output for the remaining input rows that has no matches from the // lookup at the end of current input batch processing. RowVectorPtr produceRemainingOutputForLeftJoin(const InputBatchState& batch); @@ -160,8 +179,10 @@ class IndexLookupJoin : public Operator { bool hasRemainingOutputForLeftJoin(const InputBatchState& batch) const; // Checks if we have finished processing the current 'lookupResult_'. If so, - // we reset 'lookupResult_' and corresponding processing state. + // call 'finishLookupResult' to reset 'lookupResult_' and corresponding + // processing state. void maybeFinishLookupResult(InputBatchState& batch); + void finishLookupResult(InputBatchState& batch); // Invoked after finished processing the current 'input_' batch. The function // resets the input batch and the lookup result states. @@ -232,12 +253,12 @@ class IndexLookupJoin : public Operator { const vector_size_t outputBatchSize_; // Type of join. 
const core::JoinType joinType_; - const bool includeMatchColumn_; + const bool hasMarker_; const size_t numKeys_; const RowTypePtr probeType_; const RowTypePtr lookupType_; const connector::ConnectorTableHandlePtr lookupTableHandle_; - const std::vector lookupConditions_; + const std::vector joinConditions_; const connector::ColumnHandleMap lookupColumnHandles_; const std::shared_ptr connectorQueryCtx_; const std::shared_ptr connector_; @@ -300,6 +321,24 @@ class IndexLookupJoin : public Operator { BufferPtr lookupOutputNulls_; uint64_t* rawLookupOutputNulls_{nullptr}; + // Join filter. + std::unique_ptr filter_; + + // Join filter input type. + RowTypePtr filterInputType_; + + // Maps probe-side input channels to channels in 'filterInputType_'. + std::vector filterProbeInputProjections_; + // Maps lookup-side input channels to channels in 'filterInputType_', + std::vector filterLookupInputProjections_; + + // Reusable memory for filter evaluations. + RowVectorPtr filterInput_; + SelectivityVector filterRows_; + std::vector filterResult_; + DecodedVector decodedFilterResult_; + BufferPtr filteredIndices_; + // The reusable output vector for the join output. RowVectorPtr output_; FlatVectorPtr matchColumn_{nullptr}; diff --git a/velox/exec/Merge.cpp b/velox/exec/Merge.cpp index 887fec7017eb..2c234e6aded0 100644 --- a/velox/exec/Merge.cpp +++ b/velox/exec/Merge.cpp @@ -15,6 +15,8 @@ */ #include "velox/exec/Merge.h" +#include +#include #include "velox/common/testutil/TestValue.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" @@ -22,19 +24,6 @@ using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { -namespace { -std::unique_ptr getVectorSerdeOptions( - const core::QueryConfig& queryConfig, - VectorSerde::Kind kind) { - std::unique_ptr options = - kind == VectorSerde::Kind::kPresto - ? 
std::make_unique() - : std::make_unique(); - options->compressionKind = - common::stringToCompressionKind(queryConfig.shuffleCompressionKind()); - return options; -} -} // namespace Merge::Merge( int32_t operatorId, @@ -562,14 +551,21 @@ void SpillMerger::start() { RowVectorPtr SpillMerger::getOutput( std::vector& sourceBlockingFutures, - bool& atEnd) const { + bool& atEnd) { TestValue::adjust( "facebook::velox::exec::SpillMerger::getOutput", &sourceBlockingFutures); sourceMerger_->isBlocked(sourceBlockingFutures); if (!sourceBlockingFutures.empty()) { return nullptr; } - return sourceMerger_->getOutput(sourceBlockingFutures, atEnd); + // SpillMerger::getOutput waits for all readers to finish, reaches EOF, + // and rethrows any captured error. Centralizing error propagation here + // helps prevent potential resource leaks. + auto output = sourceMerger_->getOutput(sourceBlockingFutures, atEnd); + if (atEnd) { + checkError(); + } + return output; } std::vector> SpillMerger::createMergeSources( @@ -616,61 +612,97 @@ std::unique_ptr SpillMerger::createSourceMerger( type, std::move(streams), maxOutputBatchRows, maxOutputBatchBytes, pool); } -// static. 
-void SpillMerger::asyncReadFromSpillFileStream( +void SpillMerger::finishSource(size_t streamIdx) const { + ContinueFuture future{ContinueFuture::makeEmpty()}; + sources_[streamIdx]->enqueue(nullptr, &future); + VELOX_CHECK(!future.valid()); +} + +void SpillMerger::readFromSpillFileStream( const std::weak_ptr& mergeHolder, size_t streamIdx) { TestValue::adjust( - "facebook::velox::exec::SpillMerger::asyncReadFromSpillFileStream", - static_cast(0)); + "facebook::velox::exec::SpillMerger::readFromSpillFileStream", nullptr); const auto merger = mergeHolder.lock(); if (merger == nullptr) { LOG(ERROR) << "SpillMerger is destroyed, abandon reading from batch stream"; return; } - merger->readFromSpillFileStream(streamIdx); -} -void SpillMerger::readFromSpillFileStream(size_t streamIdx) { - RowVectorPtr vector; - ContinueFuture future{ContinueFuture::makeEmpty()}; - if (!batchStreams_[streamIdx]->nextBatch(vector)) { - VELOX_CHECK_NULL(vector); - sources_[streamIdx]->enqueue(nullptr, &future); - VELOX_CHECK(!future.valid()); - return; - } - const auto blockingReason = - sources_[streamIdx]->enqueue(std::move(vector), &future); - // TODO: add async error handling. 
- if (blockingReason == BlockingReason::kNotBlocked) { - VELOX_CHECK(!future.valid()); - executor_->add( - [mergeHolder = std::weak_ptr(shared_from_this()), streamIdx]() { - asyncReadFromSpillFileStream(mergeHolder, streamIdx); - }); - } else { - VELOX_CHECK(future.valid()); - std::move(future) - .via(executor_) - .thenValue([mergeHolder = std::weak_ptr(shared_from_this()), - streamIdx](folly::Unit) { - asyncReadFromSpillFileStream(mergeHolder, streamIdx); - }) - .thenError( - folly::tag_t{}, - [streamIdx](const std::exception& e) { - LOG(ERROR) << "Stop the " << streamIdx - << "th batch stream producer on error: " << e.what(); - }); + try { + if (hasError()) { + finishSource(streamIdx); + return; + } + + RowVectorPtr vector; + if (!batchStreams_[streamIdx]->nextBatch(vector)) { + VELOX_CHECK_NULL(vector); + finishSource(streamIdx); + return; + } + + ContinueFuture future{ContinueFuture::makeEmpty()}; + const auto blockingReason = + sources_[streamIdx]->enqueue(std::move(vector), &future); + if (blockingReason == BlockingReason::kNotBlocked) { + VELOX_CHECK(!future.valid()); + readFromSpillFileStream(mergeHolder, streamIdx); + } else { + VELOX_CHECK(future.valid()); + std::move(future) + .via(executor_) + .thenValue([this, mergeHolder, streamIdx](auto&&) { + readFromSpillFileStream(mergeHolder, streamIdx); + }) + .thenError( + folly::tag_t{}, + [this, mergeHolder, streamIdx](const std::exception& e) { + const auto merger = mergeHolder.lock(); + if (merger != nullptr) { + LOG(ERROR) << "Stop the " << streamIdx + << " th source on error: " << e.what(); + setError(std::make_exception_ptr(e)); + finishSource(streamIdx); + } + }); + } + } catch (const std::exception& e) { + LOG(ERROR) << "The " << streamIdx + << " spill stream failed with error: " << e.what(); + setError(std::current_exception()); + finishSource(streamIdx); } } void SpillMerger::scheduleAsyncSpillFileStreamReads() { VELOX_CHECK_EQ(batchStreams_.size(), sources_.size()); for (auto i = 0; i < 
batchStreams_.size(); ++i) { - executor_->add( - [&, streamIdx = i]() { readFromSpillFileStream(streamIdx); }); + executor_->add([&, streamIdx = i]() { + readFromSpillFileStream(std::weak_ptr(shared_from_this()), streamIdx); + }); + } +} + +void SpillMerger::setError(const std::exception_ptr& exception) { + std::lock_guard l(mutex_); + if (exception_ != nullptr) { + return; + } + exception_ = exception; +} + +bool SpillMerger::hasError() const { + std::lock_guard l(mutex_); + return exception_ != nullptr; +} + +void SpillMerger::checkError() { + if (hasError()) { + sourceMerger_.reset(); + batchStreams_.clear(); + sources_.clear(); + std::rethrow_exception(exception_); } } @@ -725,7 +757,8 @@ MergeExchange::MergeExchange( "MergeExchange"), serde_(getNamedVectorSerde(mergeExchangeNode->serdeKind())), serdeOptions_(getVectorSerdeOptions( - driverCtx->queryConfig(), + common::stringToCompressionKind( + driverCtx->queryConfig().shuffleCompressionKind()), mergeExchangeNode->serdeKind())) {} BlockingReason MergeExchange::addMergeSources(ContinueFuture* future) { diff --git a/velox/exec/Merge.h b/velox/exec/Merge.h index f41dfebd94fb..b688b42cf268 100644 --- a/velox/exec/Merge.h +++ b/velox/exec/Merge.h @@ -15,11 +15,11 @@ */ #pragma once +#include "velox/common/base/TreeOfLosers.h" #include "velox/exec/Exchange.h" #include "velox/exec/MergeSource.h" #include "velox/exec/Spill.h" #include "velox/exec/Spiller.h" -#include "velox/exec/TreeOfLosers.h" namespace facebook::velox::exec { @@ -302,7 +302,7 @@ class SpillMerger : public std::enable_shared_from_this { RowVectorPtr getOutput( std::vector& sourceBlockingFutures, - bool& atEnd) const; + bool& atEnd); private: static std::vector> createMergeSources( @@ -321,14 +321,23 @@ class SpillMerger : public std::enable_shared_from_this { uint64_t maxOutputBatchBytes, velox::memory::MemoryPool* pool); - static void asyncReadFromSpillFileStream( + void finishSource(size_t streamIdx) const; + + void readFromSpillFileStream( const 
std::weak_ptr& mergeHolder, size_t streamIdx); - void readFromSpillFileStream(size_t streamIdx); - void scheduleAsyncSpillFileStreamReads(); + // Sets 'exception_' when an async reader throws. + void setError(const std::exception_ptr& exception); + + // Returns true if any async reader has thrown an exception. + bool hasError() const; + + // If any async reader has thrown an exception, rethrows it. + void checkError(); + folly::Executor* const executor_; const std::shared_ptr> spillStats_; const std::shared_ptr pool_; @@ -336,6 +345,8 @@ class SpillMerger : public std::enable_shared_from_this { std::vector> sources_; std::vector> batchStreams_; std::unique_ptr sourceMerger_; + mutable std::timed_mutex mutex_; + std::exception_ptr exception_ = nullptr; }; // LocalMerge merges its source's output into a single stream of diff --git a/velox/exec/OneWayStatusFlag.h b/velox/exec/OneWayStatusFlag.h index 9986f5bafe76..4eea37c24713 100644 --- a/velox/exec/OneWayStatusFlag.h +++ b/velox/exec/OneWayStatusFlag.h @@ -16,53 +16,28 @@ #pragma once -#include #include namespace facebook::velox::exec { -/// A simple one way status flag that uses a non atomic flag to avoid -/// unnecessary atomic operations. 
class OneWayStatusFlag { public: - bool check() const { -#if defined(__x86_64__) - folly::annotate_ignore_thread_sanitizer_guard g(__FILE__, __LINE__); - return fastStatus_ || atomicStatus_.load(); -#else - return atomicStatus_.load(std::memory_order_relaxed) || - atomicStatus_.load(); -#endif + bool check() const noexcept { + return status_.load(std::memory_order_acquire); } - void set() { -#if defined(__x86_64__) - folly::annotate_ignore_thread_sanitizer_guard g(__FILE__, __LINE__); - if (!fastStatus_) { - atomicStatus_.store(true); - fastStatus_ = true; + void set() noexcept { + if (!status_.load(std::memory_order_relaxed)) { + status_.store(true, std::memory_order_release); } -#else - if (!atomicStatus_.load(std::memory_order_relaxed)) { - atomicStatus_.store(true); - } -#endif } - /// Operator overload to convert OneWayStatusFlag to bool - operator bool() const { + explicit operator bool() const noexcept { return check(); } private: -#if defined(__x86_64__) - // This flag can only go from false to true, and is only checked at the end of - // a loop. Given that once a flag is true it can never go back to false, we - // are ok to use this in a non synchronized manner to avoid the overhead. As - // such we consciously exempt ourselves here from TSAN detection. 
- bool fastStatus_{false}; -#endif - std::atomic_bool atomicStatus_{false}; + std::atomic_bool status_{false}; }; } // namespace facebook::velox::exec diff --git a/velox/exec/Operator.cpp b/velox/exec/Operator.cpp index b5dcdb849c72..be45e716da59 100644 --- a/velox/exec/Operator.cpp +++ b/velox/exec/Operator.cpp @@ -72,6 +72,8 @@ OperatorCtx::createConnectorQueryCtx( task->queryCtx()->fsTokenProvider()); connectorQueryCtx->setSelectiveNimbleReaderEnabled( driverCtx_->queryConfig().selectiveNimbleReaderEnabled()); + connectorQueryCtx->setRowSizeTrackingMode( + driverCtx_->queryConfig().rowSizeTrackingMode()); return connectorQueryCtx; } diff --git a/velox/exec/OperatorUtils.cpp b/velox/exec/OperatorUtils.cpp index eb7258202a95..9bea1965ea4a 100644 --- a/velox/exec/OperatorUtils.cpp +++ b/velox/exec/OperatorUtils.cpp @@ -14,8 +14,10 @@ * limitations under the License. */ #include "velox/exec/OperatorUtils.h" +#include "velox/exec/PartitionedOutput.h" #include "velox/exec/VectorHasher.h" #include "velox/expression/EvalCtx.h" +#include "velox/serializers/PrestoSerializer.h" #include "velox/vector/ConstantVector.h" #include "velox/vector/FlatVector.h" #include "velox/vector/LazyVector.h" @@ -569,4 +571,20 @@ std::unique_ptr BlockedOperatorFactory::toOperator( } return nullptr; } + +std::unique_ptr getVectorSerdeOptions( + common::CompressionKind compressionKind, + VectorSerde::Kind kind, + std::optional minCompressionRatio) { + std::unique_ptr options = + kind == VectorSerde::Kind::kPresto + ? 
std::make_unique() + : std::make_unique(); + options->compressionKind = compressionKind; + if (minCompressionRatio.has_value()) { + options->minCompressionRatio = minCompressionRatio.value(); + } + return options; +} + } // namespace facebook::velox::exec diff --git a/velox/exec/OperatorUtils.h b/velox/exec/OperatorUtils.h index ac2698e437de..605edba4a288 100644 --- a/velox/exec/OperatorUtils.h +++ b/velox/exec/OperatorUtils.h @@ -15,8 +15,11 @@ */ #pragma once +#include +#include "velox/core/QueryConfig.h" #include "velox/exec/Operator.h" #include "velox/exec/Spiller.h" +#include "velox/vector/VectorStream.h" namespace facebook::velox::exec { @@ -307,4 +310,12 @@ class BlockedOperatorFactory : public Operator::PlanNodeTranslator { private: BlockedOperatorCb blockedCb_{nullptr}; }; + +/// Creates VectorSerde::Options for the given VectorSerde kind with compression +/// settings. Optionally configures minimum compression ratio. +std::unique_ptr getVectorSerdeOptions( + common::CompressionKind compressionKind, + VectorSerde::Kind kind, + std::optional minCompressionRatio = std::nullopt); + } // namespace facebook::velox::exec diff --git a/velox/exec/PartitionedOutput.cpp b/velox/exec/PartitionedOutput.cpp index 974cfbb3ec8f..587400713994 100644 --- a/velox/exec/PartitionedOutput.cpp +++ b/velox/exec/PartitionedOutput.cpp @@ -15,24 +15,11 @@ */ #include "velox/exec/PartitionedOutput.h" +#include "velox/exec/OperatorUtils.h" #include "velox/exec/OutputBufferManager.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { -namespace { -std::unique_ptr getVectorSerdeOptions( - const core::QueryConfig& queryConfig, - VectorSerde::Kind kind) { - std::unique_ptr options = - kind == VectorSerde::Kind::kPresto - ? 
std::make_unique() - : std::make_unique(); - options->compressionKind = - common::stringToCompressionKind(queryConfig.shuffleCompressionKind()); - options->minCompressionRatio = PartitionedOutput::minCompressionRatio(); - return options; -} -} // namespace namespace detail { Destination::Destination( @@ -203,8 +190,11 @@ PartitionedOutput::PartitionedOutput( eagerFlush_(eagerFlush), serde_(getNamedVectorSerde(planNode->serdeKind())), serdeOptions_(getVectorSerdeOptions( - operatorCtx_->driverCtx()->queryConfig(), - planNode->serdeKind())) { + common::stringToCompressionKind(operatorCtx_->driverCtx() + ->queryConfig() + .shuffleCompressionKind()), + planNode->serdeKind(), + PartitionedOutput::minCompressionRatio())) { if (!planNode->isPartitioned()) { VELOX_USER_CHECK_EQ(numDestinations_, 1); } diff --git a/velox/exec/PlanNodeStats.cpp b/velox/exec/PlanNodeStats.cpp index ed3414fed7d3..76c47b3db399 100644 --- a/velox/exec/PlanNodeStats.cpp +++ b/velox/exec/PlanNodeStats.cpp @@ -59,6 +59,14 @@ PlanNodeStats& PlanNodeStats::operator+=(const PlanNodeStats& another) { } } + for (const auto& [name, exprStats] : another.expressionStats) { + auto const [it, inserted] = + this->expressionStats.try_emplace(name, exprStats); + if (!inserted) { + it->second.add(exprStats); + } + } + // Populating number of drivers for plan nodes with multiple operators is not // useful. Each operator could have been executed in different pipelines with // different number of drivers. 
diff --git a/velox/exec/SortWindowBuild.cpp b/velox/exec/SortWindowBuild.cpp index f25175cc2cfa..86b015f023ba 100644 --- a/velox/exec/SortWindowBuild.cpp +++ b/velox/exec/SortWindowBuild.cpp @@ -16,6 +16,7 @@ #include "velox/exec/SortWindowBuild.h" #include "velox/exec/MemoryReclaimer.h" +#include "velox/exec/Window.h" namespace facebook::velox::exec { @@ -45,16 +46,19 @@ SortWindowBuild::SortWindowBuild( common::PrefixSortConfig&& prefixSortConfig, const common::SpillConfig* spillConfig, tsan_atomic* nonReclaimableSection, + folly::Synchronized* opStats, folly::Synchronized* spillStats) : WindowBuild(node, pool, spillConfig, nonReclaimableSection), numPartitionKeys_{node->partitionKeys().size()}, compareFlags_{makeCompareFlags(numPartitionKeys_, node->sortingOrders())}, pool_(pool), prefixSortConfig_(prefixSortConfig), + opStats_(opStats), spillStats_(spillStats), sortedRows_(0, memory::StlAllocator(*pool)), partitionStartRows_(0, memory::StlAllocator(*pool)) { VELOX_CHECK_NOT_NULL(pool_); + VELOX_CHECK_NOT_NULL(opStats_); allKeyInfo_.reserve(partitionKeyInfo_.size() + sortKeyInfo_.size()); allKeyInfo_.insert( allKeyInfo_.cend(), partitionKeyInfo_.begin(), partitionKeyInfo_.end()); @@ -294,14 +298,31 @@ void SortWindowBuild::noMoreInput() { pool_->release(); } -void SortWindowBuild::loadNextPartitionFromSpill() { +void SortWindowBuild::loadNextPartitionBatchFromSpill() { + // Check if current partition batch still has available partitions. If so, + // return directly. + if (currentPartition_ < static_cast(partitionStartRows_.size() - 2)) { + return; + } + + const int minReadBatchRows = spillConfig_->windowMinReadBatchRows; sortedRows_.clear(); - sortedRows_.shrink_to_fit(); + sortedRows_.reserve(minReadBatchRows); data_->clear(); + partitionStartRows_.clear(); + partitionStartRows_.reserve(minReadBatchRows); + partitionStartRows_.push_back(0); + currentPartition_ = -1; + numSpillReadBatches_++; + // Load at least #minReadBatchRows rows and a complete partition. 
The rows + // might contain multiple partitions. Record the partition boundaries as + // inMemory case. In this way, the logic of getting window partitions would be + // identical between inMemory and spill. for (;;) { auto next = merge_->next(); if (next == nullptr) { + partitionStartRows_.push_back(sortedRows_.size()); break; } @@ -324,7 +345,10 @@ void SortWindowBuild::loadNextPartitionFromSpill() { } if (newPartition) { - break; + partitionStartRows_.push_back(sortedRows_.size()); + if (sortedRows_.size() >= minReadBatchRows) { + break; + } } auto* newRow = data_->newRow(); @@ -334,16 +358,19 @@ void SortWindowBuild::loadNextPartitionFromSpill() { sortedRows_.push_back(newRow); next->pop(); } -} -std::shared_ptr SortWindowBuild::nextPartition() { - if (merge_ != nullptr) { - VELOX_CHECK(!sortedRows_.empty(), "No window partitions available"); - auto partition = folly::Range(sortedRows_.data(), sortedRows_.size()); - return std::make_shared( - data_.get(), partition, inversedInputChannels_, sortKeyInfo_); + // No more partition batches. All data is consumed. 
+ if (sortedRows_.empty()) { + partitionStartRows_.clear(); + numSpillReadBatches_--; + + auto lockedOpStats = opStats_->wlock(); + lockedOpStats->runtimeStats[Window::kWindowSpillReadNumBatches] = + RuntimeMetric(numSpillReadBatches_); } +} +std::shared_ptr SortWindowBuild::nextPartition() { VELOX_CHECK(!partitionStartRows_.empty(), "No window partitions available"); currentPartition_++; @@ -364,8 +391,7 @@ std::shared_ptr SortWindowBuild::nextPartition() { bool SortWindowBuild::hasNextPartition() { if (merge_ != nullptr) { - loadNextPartitionFromSpill(); - return !sortedRows_.empty(); + loadNextPartitionBatchFromSpill(); } return partitionStartRows_.size() > 0 && diff --git a/velox/exec/SortWindowBuild.h b/velox/exec/SortWindowBuild.h index 72875094007a..80167ffface3 100644 --- a/velox/exec/SortWindowBuild.h +++ b/velox/exec/SortWindowBuild.h @@ -32,6 +32,7 @@ class SortWindowBuild : public WindowBuild { common::PrefixSortConfig&& prefixSortConfig, const common::SpillConfig* spillConfig, tsan_atomic* nonReclaimableSection, + folly::Synchronized* opStats, folly::Synchronized* spillStats); ~SortWindowBuild() override { @@ -75,8 +76,10 @@ class SortWindowBuild : public WindowBuild { // Find the next partition start row from start. vector_size_t findNextPartitionStartRow(vector_size_t start); - // Reads next partition from spilled data into 'data_' and 'sortedRows_'. - void loadNextPartitionFromSpill(); + // Load the next partition batch if needed. If current partition batch is not + // entirely consumed, return directly. Otherwise, read next partition batch + // from spilled data into 'data_' and set pointers in 'sortedRows_'. + void loadNextPartitionBatchFromSpill(); const size_t numPartitionKeys_; @@ -92,6 +95,8 @@ class SortWindowBuild : public WindowBuild { // Config for Prefix-sort. 
const common::PrefixSortConfig prefixSortConfig_; + folly::Synchronized* const opStats_; + folly::Synchronized* const spillStats_; // allKeyInfo_ is a combination of (partitionKeyInfo_ and sortKeyInfo_). @@ -121,5 +126,8 @@ class SortWindowBuild : public WindowBuild { // Used to sort-merge spilled data. std::unique_ptr> merge_; + + // Number of batches of whole partitions read from spilled data. + uint64_t numSpillReadBatches_ = 0; }; } // namespace facebook::velox::exec diff --git a/velox/exec/SpatialIndex.cpp b/velox/exec/SpatialIndex.cpp new file mode 100644 index 000000000000..696b3dc52a73 --- /dev/null +++ b/velox/exec/SpatialIndex.cpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/exec/SpatialIndex.h" + +namespace facebook::velox::exec { + +SpatialIndex::SpatialIndex(std::vector envelopes) { + std::ranges::sort(envelopes, {}, &Envelope::minX); + + minXs_.reserve(envelopes.size()); + minYs_.reserve(envelopes.size()); + maxXs_.reserve(envelopes.size()); + maxYs_.reserve(envelopes.size()); + rowIndices_.reserve(envelopes.size()); + + for (const auto& env : envelopes) { + bounds_.maxX = std::max(bounds_.maxX, env.maxX); + bounds_.maxY = std::max(bounds_.maxY, env.maxY); + bounds_.minX = std::min(bounds_.minX, env.minX); + bounds_.minY = std::min(bounds_.minY, env.minY); + minXs_.push_back(env.minX); + minYs_.push_back(env.minY); + maxXs_.push_back(env.maxX); + maxYs_.push_back(env.maxY); + rowIndices_.push_back(env.rowIndex); + } +} + +std::vector SpatialIndex::query(const Envelope& queryEnv) const { + std::vector result; + if (!Envelope::intersects(queryEnv, bounds_)) { + return result; + } + + // Find the last minX that is <= queryEnv.maxX . These first envelopes + // are the only ones that can intersect the query envelope. + // `it` is _one past_ the last element, so we iterate up to it - 1. 
+ auto it = std::upper_bound(minXs_.begin(), minXs_.end(), queryEnv.maxX); + if (it == minXs_.begin()) { + return result; + } + + auto lastIdx = std::distance(minXs_.begin(), it); + VELOX_CHECK_GT(lastIdx, 0); + + for (size_t idx = 0; idx < lastIdx; ++idx) { + bool intersects = (queryEnv.maxY >= minYs_[idx]) && + (queryEnv.minX <= maxXs_[idx]) && (queryEnv.minY <= maxYs_[idx]); + if (intersects) { + result.push_back(rowIndices_[idx]); + } + } + + return result; +} + +Envelope SpatialIndex::bounds() const { + return bounds_; +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/SpatialIndex.h b/velox/exec/SpatialIndex.h new file mode 100644 index 000000000000..1d1c8b5e0ca3 --- /dev/null +++ b/velox/exec/SpatialIndex.h @@ -0,0 +1,157 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace facebook::velox::exec { + +/// A minimal envelope for a geometry. +/// It also includes an index for the geometry for later reference. This can +/// be -1 if the geometry is not indexed. +/// +/// Envelopes use float32s instead of float64s so that SIMD loops can be +/// twice as fast. Our geometries use float64 coordinates, so we have to +/// downcast them for the envelope. The loss of precision is theoretically fine +/// because the envelope checks are already approximate: either they don't +/// intersect or they might intersect. 
Thus, expanding the envelopes slightly +/// does not affect correctness (but it might affect efficiency slightly). +/// +/// We want to show that if two envelopes expressed with float64 precision would +/// intersect, the envelopes with float32 precision would also intersect. +/// +/// Define +/// ``` +/// nextUp(f) = std::nextafter(f, std::numeric_limits::infinity()) +/// nextDown(f) = std::nextafter(f, -std::numeric_limits::infinity()) +/// ``` +/// which move a float up or down one ulp (unit in the last place). +/// +/// Since the conditions are all of the form `maxX >= minX` for float64s maxX +/// and minX, we need to show that this implies `nextUp((float) maxX) >= +/// nextDown((float) minX)`. +/// +/// Assume you have a double `d` and two adjacent floats `f0` and `f1`, such +/// that `d` is "between" `f0` and `f1`: +/// +/// 1. `(double) f0 <= d <= (double) f1` +/// 2. `nextup(f0) == f1 && f0 = nextdown(f0)` +/// +/// This implies `nextdown((float) d) <= f0 && nextup((float) d) >= f1`. +/// +/// Let double `minX` have two adjacent floats `f0`, `f1` as above, and `maxX` +/// have two adjacent floats `g0`, `g1`. Then +/// ``` +/// (double) nextDown((float) minX) +/// <= (double) f0 +/// <= minX +/// <= maxX +/// <= (double) g1 +/// <= (double) nextUp((float) maxX) +/// ``` +/// +/// And this implies `nextDown((float) minX) <= nextUp((float) maxX)` as +/// desired. The same argument applies to all for members, so if we construct +/// the float32 precision envelope by applying nextDown to the minX/Ys and +/// nextUp to maxX/Ys, the float32 envelope intersects in all cases that the +/// float64 envelope would (but not necessarily the converse). 
+struct Envelope { + float minX; + float minY; + float maxX; + float maxY; + int32_t rowIndex = -1; + + static inline bool intersects(const Envelope& left, const Envelope& right) { + return (left.maxX >= right.minX) && (left.minX <= right.maxX) && + (left.maxY >= right.minY) && (left.minY <= right.maxY); + } + + inline bool isEmpty() const { + return (maxX < minX) || (maxY < minY); + } + + static constexpr inline Envelope empty() { + return Envelope{ + .minX = std::numeric_limits::infinity(), + .minY = std::numeric_limits::infinity(), + .maxX = -std::numeric_limits::infinity(), + .maxY = -std::numeric_limits::infinity()}; + } + + static constexpr inline Envelope from( + double minX, + double minY, + double maxX, + double maxY, + int32_t rowIndex = -1) { + return Envelope{ + .minX = std::nextafterf( + static_cast(minX), -std::numeric_limits::infinity()), + .minY = std::nextafterf( + static_cast(minY), -std::numeric_limits::infinity()), + .maxX = std::nextafterf( + static_cast(maxX), std::numeric_limits::infinity()), + .maxY = std::nextafterf( + static_cast(maxY), std::numeric_limits::infinity()), + .rowIndex = rowIndex}; + } +}; + +/// A spatial index for a set of geometries. The index only cares about the +/// envelopes of the geometries, and an index into the geometries (not stored in +/// SpatialIndex). +/// +/// The contract is that SpatialIndex::probe returns the indices of all +/// envelopes that probeEnv intersects. The form of the index is an +/// implementation detail. The order of the returned indicies is an +/// implementation detail. +class SpatialIndex { + public: + SpatialIndex(const SpatialIndex&) = delete; + SpatialIndex& operator=(const SpatialIndex&) = delete; + + SpatialIndex() = default; + SpatialIndex(SpatialIndex&&) = default; + SpatialIndex& operator=(SpatialIndex&&) = default; + ~SpatialIndex() = default; + + explicit SpatialIndex(std::vector envelopes); + + /// Returns the row indices of all envelopes that probeEnv intersects. 
+ /// Order of the returned indices is an implementation detail and cannot be + /// relied upon. + std::vector query(const Envelope& queryEnv) const; + + /// Returns the envelope of the all envelopes in the index. + /// The returned envelope will have index = -1. + Envelope bounds() const; + + private: + Envelope bounds_ = Envelope::empty(); + + std::vector minXs_{}; + std::vector minYs_{}; + std::vector maxXs_{}; + std::vector maxYs_{}; + std::vector rowIndices_{}; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/SpatialJoinProbe.cpp b/velox/exec/SpatialJoinProbe.cpp index bc49d24cba47..fd79060e71c4 100644 --- a/velox/exec/SpatialJoinProbe.cpp +++ b/velox/exec/SpatialJoinProbe.cpp @@ -42,6 +42,104 @@ std::vector extractProjections( } // namespace +////////////////// +// OUTPUT BUILDER + +void SpatialJoinOutputBuilder::initializeOutput( + const RowVectorPtr& input, + memory::MemoryPool* pool) { + if (output_ == nullptr) { + output_ = + BaseVector::create(outputType_, outputBatchSize_, pool); + } else { + VectorPtr outputVector = std::move(output_); + BaseVector::prepareForReuse(outputVector, outputBatchSize_); + output_ = std::static_pointer_cast(outputVector); + } + probeOutputIndices_ = allocateIndices(outputBatchSize_, pool); + rawProbeOutputIndices_ = probeOutputIndices_->asMutable(); + + // Add probe side projections as dictionary vectors + for (const auto& projection : probeProjections_) { + output_->childAt(projection.outputChannel) = wrapChild( + outputBatchSize_, + probeOutputIndices_, + input->childAt(projection.inputChannel)); + } + + // Add build side projections as uninitialized vectors + for (const auto& projection : buildProjections_) { + auto child = output_->childAt(projection.outputChannel); + if (child == nullptr) { + child = BaseVector::create( + outputType_->childAt(projection.outputChannel), + outputBatchSize_, + operatorCtx_.pool()); + } + } +} + +void SpatialJoinOutputBuilder::addOutputRow( + vector_size_t probeRow, + 
vector_size_t buildRow) { + VELOX_CHECK_NOT_NULL(probeOutputIndices_); + // Probe side is always a dictionary; just populate the index. + rawProbeOutputIndices_[outputRow_] = probeRow; + + // For the build side, we accumulate the ranges to copy, then copy all of + // them at once. Consecutive records are copied in one memcpy. + if (!buildCopyRanges_.empty() && + (buildCopyRanges_.back().sourceIndex + buildCopyRanges_.back().count) == + buildRow) { + ++buildCopyRanges_.back().count; + } else { + buildCopyRanges_.push_back({buildRow, outputRow_, 1}); + } + ++outputRow_; +} + +void SpatialJoinOutputBuilder::copyBuildValues( + const RowVectorPtr& buildVector) { + if (buildCopyRanges_.empty()) { + return; + } + + VELOX_CHECK_NOT_NULL(output_); + + for (const auto& projection : buildProjections_) { + const auto& buildChild = buildVector->childAt(projection.inputChannel); + const auto& outputChild = output_->childAt(projection.outputChannel); + outputChild->copyRanges(buildChild.get(), buildCopyRanges_); + } + buildCopyRanges_.clear(); +} + +void SpatialJoinOutputBuilder::addProbeMismatchRow(vector_size_t probeRow) { + VELOX_CHECK_NOT_NULL(output_); + + // Probe side is always a dictionary; just populate the index. + rawProbeOutputIndices_[outputRow_] = probeRow; + + // Null out build projections. 
+ for (const auto& projection : buildProjections_) { + const auto& outputChild = output_->childAt(projection.outputChannel); + outputChild->setNull(outputRow_, true); + } + ++outputRow_; +} + +RowVectorPtr SpatialJoinOutputBuilder::takeOutput() { + VELOX_CHECK(buildCopyRanges_.empty()); + if (outputRow_ == 0 || !output_) { + return nullptr; + } + RowVectorPtr output = std::move(output_); + output->resize(outputRow_); + output_ = nullptr; + outputRow_ = 0; + return output; +} + SpatialJoinProbe::SpatialJoinProbe( int32_t operatorId, DriverCtx* driverCtx, @@ -54,22 +152,34 @@ SpatialJoinProbe::SpatialJoinProbe( "SpatialJoinProbe"), joinType_(joinNode->joinType()), outputBatchSize_{outputBatchRows()}, - joinNode_(joinNode) { - auto probeType = joinNode_->sources()[0]->outputType(); - auto buildType = joinNode_->sources()[1]->outputType(); - identityProjections_ = extractProjections(probeType, outputType_); - buildProjections_ = extractProjections(buildType, outputType_); + joinNode_(joinNode), + buildProjections_(extractProjections( + joinNode_->rightNode()->outputType(), + outputType_)), + outputBuilder_{ + outputBatchSize_, + outputType_, + extractProjections( + joinNode_->leftNode()->outputType(), + outputType_), // these are the identity Projections + buildProjections_, + *operatorCtx_} { + identityProjections_ = + extractProjections(joinNode_->leftNode()->outputType(), outputType_); } +///////// +// SETUP + void SpatialJoinProbe::initialize() { Operator::initialize(); - VELOX_CHECK(joinNode_ != nullptr); + VELOX_CHECK_NOT_NULL(joinNode_); if (joinNode_->joinCondition() != nullptr) { initializeFilter( joinNode_->joinCondition(), - joinNode_->sources()[0]->outputType(), - joinNode_->sources()[1]->outputType()); + joinNode_->leftNode()->outputType(), + joinNode_->rightNode()->outputType()); } joinNode_.reset(); @@ -154,29 +264,11 @@ void SpatialJoinProbe::close() { Operator::close(); } -void SpatialJoinProbe::addInput(RowVectorPtr input) { - 
VELOX_CHECK_NULL(input_); - - // In getOutput(), we are going to wrap input in dictionaries a few rows at a - // time. Since lazy vectors cannot be wrapped in different dictionaries, we - // are going to load them here. - for (auto& child : input->children()) { - child->loadedVector(); - } - input_ = std::move(input); - if (input_->size() > 0) { - probeSideEmpty_ = false; - } - VELOX_CHECK_EQ(buildIndex_, 0); -} - void SpatialJoinProbe::noMoreInput() { Operator::noMoreInput(); - if (state_ != ProbeOperatorState::kRunning || input_ != nullptr) { - return; + if (state_ == ProbeOperatorState::kRunning && input_ == nullptr) { + setState(ProbeOperatorState::kFinish); } - setState(ProbeOperatorState::kFinish); - return; } bool SpatialJoinProbe::getBuildData(ContinueFuture* future) { @@ -195,11 +287,49 @@ bool SpatialJoinProbe::getBuildData(ContinueFuture* future) { return true; } +void SpatialJoinProbe::checkStateTransition(ProbeOperatorState state) { + VELOX_CHECK_NE(state_, state); + switch (state) { + case ProbeOperatorState::kRunning: + VELOX_CHECK_EQ(state_, ProbeOperatorState::kWaitForBuild); + break; + case ProbeOperatorState::kWaitForBuild: + [[fallthrough]]; + case ProbeOperatorState::kFinish: + VELOX_CHECK_EQ(state_, ProbeOperatorState::kRunning); + break; + default: + VELOX_UNREACHABLE(probeOperatorStateName(state_)); + break; + } +} + +//////////////// +// INPUT/OUTPUT + +void SpatialJoinProbe::addInput(RowVectorPtr input) { + VELOX_CHECK_NULL(input_); + VELOX_CHECK_EQ(probeRow_, 0); + VELOX_CHECK(!probeHasMatch_); + VELOX_CHECK_EQ(buildIndex_, 0); + VELOX_CHECK_EQ(buildRow_, 0); + + // In getOutput(), we are going to wrap input in dictionaries a few rows at a + // time. Since lazy vectors cannot be wrapped in different dictionaries, we + // are going to load them here. 
+ for (auto& child : input->children()) { + child->loadedVector(); + } + input_ = std::move(input); + ++probeCount_; +} + RowVectorPtr SpatialJoinProbe::getOutput() { if (state_ == ProbeOperatorState::kFinish || state_ == ProbeOperatorState::kWaitForPeers) { return nullptr; } + RowVectorPtr output{nullptr}; while (output == nullptr) { // Need more input. @@ -207,151 +337,102 @@ RowVectorPtr SpatialJoinProbe::getOutput() { break; } + // If the task owning this operator isn't running, there is no point + // to continue executing this procedure, which may be long in degenerate + // cases. Exit the working loop and let the Driver handle exiting + // gracefully in its own loop. + if (!operatorCtx_->task()->isRunning()) { + break; + } + + if (shouldYield()) { + break; + } + // Generate actual join output by processing probe and build matches, and // probe mismaches (for left joins). output = generateOutput(); } + + if (output != nullptr) { + ++outputCount_; + } return output; } RowVectorPtr SpatialJoinProbe::generateOutput() { - // If addToOutput() returns false, output_ is filled. Need to produce it. - if (!addToOutput()) { - VELOX_CHECK_GT(output_->size(), 0); - return std::move(output_); + VELOX_CHECK_NOT_NULL(input_); + VELOX_CHECK_GT(input_->size(), probeRow_); + outputBuilder_.initializeOutput(input_, pool()); + + while (!isOutputDone()) { + // Fill output_ with the results from one row. This may produce too + // much output and only partially complete. If so, the next time we + // call this we'll get the next chunk. + // + // addProbeRowOutput is responsible for advancing probeRow_. + addProbeRowOutput(); } - // Try to advance the probe cursor; call finish if no more probe input. - if (advanceProbe()) { + // If we've exhausted the input, release it. + if (probeRow_ >= input_->size()) { finishProbeInput(); - if (numOutputRows_ == 0) { - // output_ can only be re-used across probe rows within the same input_. 
- // Here we have to abandon the emtpy non-null output_ before we advance to - // the next probe input. - output_ = nullptr; - } } - if (!readyToProduceOutput()) { - return nullptr; - } - - output_->resize(numOutputRows_); - return std::move(output_); + return outputBuilder_.takeOutput(); } -bool SpatialJoinProbe::readyToProduceOutput() { - if (!output_ || numOutputRows_ == 0) { - return false; - } - - // If the input_ has no remaining rows or the output_ is fully filled, - // it's right time for output. - return !input_ || numOutputRows_ >= outputBatchSize_; -} - -bool SpatialJoinProbe::advanceProbe() { - if (hasProbedAllBuildData()) { - probeRow_ += 1; - probeRowHasMatch_ = false; - buildIndex_ = 0; - - // If we finished processing the probe side. - if (probeRow_ >= input_->size()) { - return true; +// Return true if adding output stops early because output is full. +void SpatialJoinProbe::addProbeRowOutput() { + VELOX_CHECK(buildVectors_.has_value()); + VELOX_CHECK(!outputBuilder_.isOutputFull()); + + while (!isProbeRowDone()) { + addBuildVectorOutput(buildVectors_.value()[buildIndex_]); + if (outputBuilder_.isOutputFull()) { + // If full, don't advance buildIndex_ because we may not have exhausted + // the current vector. Return instead of breaking so that we can add a + // mismatch row later if necessary. + return; } + advanceBuildVector(); } - return false; -} -bool SpatialJoinProbe::addToOutput() { - VELOX_CHECK_NOT_NULL(input_); - prepareOutput(); - - while (!hasProbedAllBuildData()) { - const auto& currentBuild = buildVectors_.value()[buildIndex_]; - - // Empty build vector; move to the next. - if (currentBuild->size() == 0) { - ++buildIndex_; - buildRow_ = 0; - continue; - } - - // Only re-calculate the filter if we have a new build vector. - if (buildRow_ == 0) { - evaluateSpatialJoinFilter(currentBuild); - } - - // Iterate over the filter results. For each match, add an output record. 
- for (vector_size_t i = buildRow_; i < decodedFilterResult_.size(); ++i) { - if (!isSpatialJoinConditionMatch(i)) { - continue; - } - - addOutputRow(i); - ++numOutputRows_; - probeRowHasMatch_ = true; - - // If the buffer is full, save state and produce it as output. - if (numOutputRows_ == outputBatchSize_) { - buildRow_ = i + 1; - copyBuildValues(currentBuild); - return false; - } - } - - // Before moving to the next build vector, copy the needed ranges. - copyBuildValues(currentBuild); - ++buildIndex_; - buildRow_ = 0; + // Now that we have finished the probe row, check if we need to add a probe + // mismatch record. + if (!probeHasMatch_ && needsProbeMismatch(joinType_)) { + outputBuilder_.addProbeMismatchRow(probeRow_); } - - // Check if the current probed row needs to be added as a mismatch (for left - // and full outer joins). - checkProbeMismatchRow(); - - // Signals that all input has been generated for the probeRow and build - // vectors; safe to move to the next probe record. - return true; + // Advance here instead of the loop in generateOutput so that early return on + // full doesn't advance the probe. + advanceProbeRow(); } -void SpatialJoinProbe::prepareOutput() { - if (output_ != nullptr) { - return; +void SpatialJoinProbe::addBuildVectorOutput(const RowVectorPtr& buildVector) { + if (FOLLY_UNLIKELY(buildRow_ == 0)) { + // Evaluate join filter for the whole vector just once. + evaluateJoinFilter(buildVector); } - std::vector localColumns(outputType_->size()); - - probeOutputIndices_ = allocateIndices(outputBatchSize_, pool()); - rawProbeOutputIndices_ = probeOutputIndices_->asMutable(); - - for (const auto& projection : identityProjections_) { - localColumns[projection.outputChannel] = BaseVector::wrapInDictionary( - {}, - probeOutputIndices_, - outputBatchSize_, - input_->childAt(projection.inputChannel)); - } + // Start where we left off: after the last buildRow_ that was processed. 
+ while (!isBuildVectorDone(buildVector)) { + if (isJoinConditionMatch(buildRow_)) { + outputBuilder_.addOutputRow(probeRow_, buildRow_); + probeHasMatch_ = true; + } - // For other join types, add build side projections - for (const auto& projection : buildProjections_) { - localColumns[projection.outputChannel] = BaseVector::create( - outputType_->childAt(projection.outputChannel), - outputBatchSize_, - operatorCtx_->pool()); + // Advance buildRow_ even if full, since we're finished with this row. + ++buildRow_; } - numOutputRows_ = 0; - output_ = std::make_shared( - pool(), outputType_, nullptr, outputBatchSize_, std::move(localColumns)); + // Since we are copying from the current buildVector, we must copy here. + outputBuilder_.copyBuildValues(buildVector); } -void SpatialJoinProbe::evaluateSpatialJoinFilter( - const RowVectorPtr& buildVector) { +void SpatialJoinProbe::evaluateJoinFilter(const RowVectorPtr& buildVector) { // First step to process is to get a batch so we can evaluate the join // filter. 
- auto filterInput = getNextCrossProductBatch( + auto filterInput = getNextJoinBatch( buildVector, filterInputType_, filterProbeProjections_, @@ -371,22 +452,13 @@ void SpatialJoinProbe::evaluateSpatialJoinFilter( decodedFilterResult_.decode(*filterOutput_, filterInputRows_); } -RowVectorPtr SpatialJoinProbe::getNextCrossProductBatch( +RowVectorPtr SpatialJoinProbe::getNextJoinBatch( const RowVectorPtr& buildVector, const RowTypePtr& outputType, const std::vector& probeProjections, - const std::vector& buildProjections) { + const std::vector& buildProjections) const { VELOX_CHECK_GT(buildVector->size(), 0); - return genCrossProductMultipleBuildVectors( - buildVector, outputType, probeProjections, buildProjections); -} - -RowVectorPtr SpatialJoinProbe::genCrossProductMultipleBuildVectors( - const RowVectorPtr& buildVector, - const RowTypePtr& outputType, - const std::vector& probeProjections, - const std::vector& buildProjections) { std::vector projectedChildren(outputType->size()); const vector_size_t numOutputRows = buildVector->size(); @@ -404,68 +476,14 @@ RowVectorPtr SpatialJoinProbe::genCrossProductMultipleBuildVectors( pool(), outputType, nullptr, numOutputRows, std::move(projectedChildren)); } -void SpatialJoinProbe::addOutputRow(vector_size_t buildRow) { - // Probe side is always a dictionary; just populate the index. - rawProbeOutputIndices_[numOutputRows_] = probeRow_; - - // For the build side, we accumulate the ranges to copy, then copy all of them - // at once. If records are consecutive and can have a single copy range run. 
- if (!buildCopyRanges_.empty() && - (buildCopyRanges_.back().sourceIndex + buildCopyRanges_.back().count) == - buildRow) { - ++buildCopyRanges_.back().count; - } else { - buildCopyRanges_.push_back({buildRow, numOutputRows_, 1}); - } -} - -void SpatialJoinProbe::copyBuildValues(const RowVectorPtr& buildVector) { - if (buildCopyRanges_.empty() || isLeftSemiProjectJoin(joinType_)) { - return; - } - - for (const auto& projection : buildProjections_) { - const auto& buildChild = buildVector->childAt(projection.inputChannel); - const auto& outputChild = output_->childAt(projection.outputChannel); - outputChild->copyRanges(buildChild.get(), buildCopyRanges_); - } - buildCopyRanges_.clear(); -} - -void SpatialJoinProbe::checkProbeMismatchRow() { - // If we are processing the last batch of the build side, check if we need - // to add a probe mismatch record. - if (needsProbeMismatch(joinType_) && hasProbedAllBuildData() && - !probeRowHasMatch_) { - prepareOutput(); - addProbeMismatchRow(); - ++numOutputRows_; - } -} - -void SpatialJoinProbe::addProbeMismatchRow() { - // Probe side is always a dictionary; just populate the index. - rawProbeOutputIndices_[numOutputRows_] = probeRow_; - - // Null out build projections. 
- for (const auto& projection : buildProjections_) { - const auto& outputChild = output_->childAt(projection.outputChannel); - outputChild->setNull(numOutputRows_, true); - } -} - void SpatialJoinProbe::finishProbeInput() { VELOX_CHECK_NOT_NULL(input_); input_.reset(); - buildIndex_ = 0; probeRow_ = 0; - if (!noMoreInput_) { - return; + if (noMoreInput_) { + setState(ProbeOperatorState::kFinish); } - - setState(ProbeOperatorState::kFinish); - return; } } // namespace facebook::velox::exec diff --git a/velox/exec/SpatialJoinProbe.h b/velox/exec/SpatialJoinProbe.h index 54b2cbe701d4..6de342edb0aa 100644 --- a/velox/exec/SpatialJoinProbe.h +++ b/velox/exec/SpatialJoinProbe.h @@ -21,9 +21,62 @@ namespace facebook::velox::exec { +class SpatialJoinOutputBuilder { + public: + SpatialJoinOutputBuilder( + vector_size_t outputBatchSize, + RowTypePtr outputType, + std::vector probeProjections, + std::vector buildProjections, + const OperatorCtx& operatorCtx) + : outputBatchSize_{outputBatchSize}, + outputType_{std::move(outputType)}, + probeProjections_{std::move(probeProjections)}, + buildProjections_{std::move(buildProjections)}, + operatorCtx_{operatorCtx} { + VELOX_CHECK_GT(outputBatchSize_, 0); + } + + void initializeOutput(const RowVectorPtr& input, memory::MemoryPool* pool); + + bool isOutputFull() const { + return outputRow_ >= outputBatchSize_; + } + + void addOutputRow(vector_size_t probeRow, vector_size_t buildRow); + + /// Checks if it is required to add a probe mismatch row, and does it if + /// needed. The caller needs to ensure there is available space in `output_` + /// for the new record, which has nulled out build projections. 
+ void addProbeMismatchRow(vector_size_t probeRow); + + void copyBuildValues(const RowVectorPtr& buildVector); + + RowVectorPtr takeOutput(); + + private: + // Initialization parameters + const vector_size_t outputBatchSize_; + const RowTypePtr outputType_; + const std::vector probeProjections_; + const std::vector buildProjections_; + const OperatorCtx& operatorCtx_; + + // Output state + RowVectorPtr output_; + vector_size_t outputRow_{0}; + // Dictionary indices for probe columns for output vector. + BufferPtr probeOutputIndices_; + // Mutable pointer to probeOutputIndices_ + vector_size_t* rawProbeOutputIndices_{}; + + // Stores the ranges of build values to be copied to the output vector (we + // batch them and copy once, instead of copying them row-by-row). + std::vector buildCopyRanges_{}; +}; + /// Implements a Spatial Join between records from the probe (input_) -/// and build (SpatialJoinBridge) sides. It supports inner, left, right and -/// full outer joins. +/// and build (SpatialJoinBridge) sides. It supports inner and left joins. /// /// This class is designed to evaluate spatial join conditions (e.g. /// ST_INTERSECTS, ST_CONTAINS, ST_WITHIN) between geometric data types. It can @@ -79,6 +132,13 @@ class SpatialJoinProbe : public Operator { void close() override; private: + void checkStateTransition(ProbeOperatorState state); + + void setState(ProbeOperatorState state) { + checkStateTransition(state); + state_ = state; + } + // Initialize spatial filter for evaluating spatial join conditions. void initializeFilter( const core::TypedExprPtr& filter, @@ -91,94 +151,14 @@ class SpatialJoinProbe : public Operator { // `buildVectors_` before it can produce output. bool getBuildData(ContinueFuture* future); - // Generates output from spatial join matches between probe and build sides, - // as well as probe mismatches (for left and full outer joins). 
As much as - // possible, generates outputs `outputBatchSize_` records at a time, but - // batches may be smaller in some cases - outputs follow the probe side buffer - // boundaries. + // Produce as much output as possible for the current input. RowVectorPtr generateOutput(); - // For non cross-join mode, the `output_` can be reused across multiple probe - // rows. If the input_ has remaining rows and the output_ is not fully filled, - // it returns false here. - bool readyToProduceOutput(); - - // Fill in joined output to `output_` by matching the current probeRow_ and - // successive build vectors (using getNextCrossProductBatch()). Stops when - // either all build vectors were matched for the current probeRow (returns - // true), or if the output is full (returns false). If it returns false, a - // valid vector with more than zero records will be available at `output_`; - // if it returns true, either nullptr or zero records may be placed at - // `output_`. Also if it returns true, it's the caller's responsibility to - // decide when to set `output_` size. - // - // Also updates `buildMatched_` if the build records that received a match, so - // that they can be used to implement right and full outer join semantic once - // all probe data has been processed. - bool addToOutput(); - - // Advances 'probeRow_' and resets required state information. Returns true - // if there is no more probe data to be processed in the current `input_` - // (and hence a new probe input is required). False otherwise. - bool advanceProbe(); - - // Ensures a new batch of records is available at `output_` and ready to - // receive rows. Batches have space for `outputBatchSize_`. - void prepareOutput(); - - // Evaluates the spatial joinCondition for a given build vector. This method - // sets `filterOutput_` and `decodedFilterResult_`, which will be ready to be - // used by `isSpatialJoinConditionMatch(buildRow)` below. 
- void evaluateSpatialJoinFilter(const RowVectorPtr& buildVector); - - // Checks if the spatial join condition matched for a particular row. - bool isSpatialJoinConditionMatch(vector_size_t i) const { - return ( - !decodedFilterResult_.isNullAt(i) && - decodedFilterResult_.valueAt(i)); + // Returns true if the input is exhausted or the output is full. + bool isOutputDone() const { + return probeRow_ >= input_->size() || outputBuilder_.isOutputFull(); } - // Generates the next batch of a cross product between probe and build. It - // should be used as the entry point, and will internally delegate to one of - // the three functions below. - // - // Output projections can be specified so that this function can be used to - // generate both filter input and actual output (in case there is no join - // filter - cross join). - RowVectorPtr getNextCrossProductBatch( - const RowVectorPtr& buildVector, - const RowTypePtr& outputType, - const std::vector& probeProjections, - const std::vector& buildProjections); - - // As a fallback, process the current probe row to as much build data as - // possible (probe row as constant, and flat copied data for build records). - RowVectorPtr genCrossProductMultipleBuildVectors( - const RowVectorPtr& buildVector, - const RowTypePtr& outputType, - const std::vector& probeProjections, - const std::vector& buildProjections); - - // Add a single record to `output_` based on buildRow from buildVector, and - // the current probeRow and probe vector (input_). Probe side projections are - // zero-copy (dictionary indices), and build side projections are marked to be - // copied using `buildCopyRanges_`; they will be copied later on by - // `copyBuildValues()`. - void addOutputRow(vector_size_t buildRow); - - // Copies the ranges from buildVector specified by `buildCopyRanges_` to - // `output_`, one projected column at a time. Clears buildCopyRanges_. 
- void copyBuildValues(const RowVectorPtr& buildVector); - - // Checks if it is required to add a probe mismatch row, and does it if - // needed. The caller needs to ensure there is available space in `output_` - // for the new record, which has nulled out build projections. - void checkProbeMismatchRow(); - - // Add a probe mismatch (only for left/full outer joins). The record is based - // on the current probeRow and vector (input_) and build projections are null. - void addProbeMismatchRow(); - // Called when we are done processing the current probe batch, to signal we // are ready for the next one. // @@ -186,118 +166,138 @@ class SpatialJoinProbe : public Operator { // change the operator state to signal peers. void finishProbeInput(); - // Whether we have processed all build data for the current probe row (based - // on buildIndex_'s value). - bool hasProbedAllBuildData() const { - return (buildIndex_ >= buildVectors_.value().size()); - } + // Add the output for a single probe row. This will return early if the + // output vector is full. + void addProbeRowOutput(); - // If build has a single vector, we can wrap probe and build batches into - // dictionaries and produce as many combinations of probe and build rows, - // until `numOutputRows_` is filled. - bool isSingleBuildVector() const { - return buildVectors_->size() == 1; + // Returns true if all output for the current probe row has been produced. + bool isProbeRowDone() const { + return buildIndex_ >= buildVectors_.value().size(); } - // If there are no incoming records in the build side. - bool isBuildSideEmpty() const { - return buildVectors_->empty(); + // Increment probeRow_ and reset associated fields + void advanceProbeRow() { + ++probeRow_; + probeHasMatch_ = false; + buildIndex_ = 0; + buildRow_ = 0; } - // If build has a single row, we can simply add it as a constant to probe - // batches. 
- bool isSingleBuildRow() const { - return isSingleBuildVector() && buildVectors_->front()->size() == 1; - } + // Add the output for a single build vector for a single probe row. This will + // return early if the output vector is full. + void addBuildVectorOutput(const RowVectorPtr& buildVector); - // TODO: Add state transition check. - void setState(ProbeOperatorState state) { - state_ = state; + // Returns true if all the rows for the current build vector have been + // processed, or the output is full. + bool isBuildVectorDone(const RowVectorPtr& buildVector) const { + return buildRow_ >= buildVector->size() || outputBuilder_.isOutputFull(); } - const core::JoinType joinType_; + // Increment buildIndex_ and reset associated fields + void advanceBuildVector() { + ++buildIndex_; + buildRow_ = 0; + } - // Output buffer members. + // Evaluates the spatial joinCondition for a given build vector. This method + // sets `filterOutput_` and `decodedFilterResult_`, which will be ready to be + // used by `isSpatialJoinConditionMatch(buildRow)` below. + void evaluateJoinFilter(const RowVectorPtr& buildVector); - // Maximum number of rows in the output batch. - const vector_size_t outputBatchSize_; + // Checks if the spatial join condition matched for a particular row. + bool isJoinConditionMatch(vector_size_t i) const { + return ( + !decodedFilterResult_.isNullAt(i) && + decodedFilterResult_.valueAt(i)); + } - // The current output batch being populated. - RowVectorPtr output_; + // Generates the next batch of a cross product between probe and build using + // the supplied projections. It uses the current probe row as constant, and + // flat copied data for build records. + RowVectorPtr getNextJoinBatch( + const RowVectorPtr& buildVector, + const RowTypePtr& outputType, + const std::vector& probeProjections, + const std::vector& buildProjections) const; - // Number of output rows in the current output batch. 
- vector_size_t numOutputRows_{0}; + ///////// + // SETUP + // Variables set during operator setup that are used during execution. + // These should not be modified after the operator is initialized. - // Dictionary indices for probe columns used to generate cross-product. - BufferPtr probeIndices_; + const core::JoinType joinType_; - // Dictionary indices for probe columns for output vector. - BufferPtr probeOutputIndices_; - vector_size_t* rawProbeOutputIndices_{}; + // Maximum number of rows in the output batch. + const vector_size_t outputBatchSize_; - // Dictionary indices for build columns. - BufferPtr buildIndices_; + // Join metadata and state. + std::shared_ptr joinNode_; // Spatial join condition expression. - // Must not be null std::unique_ptr joinCondition_; // Input type for the spatial join condition expression. RowTypePtr filterInputType_; + // List of output projections from the build side. Note that the list of + // projections from the probe side is available at `identityProjections_`. + std::vector buildProjections_; + + // Projections needed as input to the filter to evaluation spatial join filter + // conditions. Note that if this is a cross-join, filter projections are the + // same as output projections. + std::vector filterProbeProjections_; + std::vector filterBuildProjections_; + + // Stores the data for build vectors (right side of the join). + std::optional> buildVectors_; + + ////////////////// + // OPERATOR STATE + // Variables used to track the general operator state during exection. + // These will change throughout setup and execution. + + ProbeOperatorState state_{ProbeOperatorState::kWaitForBuild}; + ContinueFuture future_{ContinueFuture::makeEmpty()}; + + // The information needed to produce an output RowVectorPtr. It is stored + // for all execution, but is reset on each output batch. + SpatialJoinOutputBuilder outputBuilder_; + + // Count of output batches produced (1-indexed). Primarily for debugging. 
+ size_t outputCount_{0}; + // Spatial join condition evaluation state that need to persisted across the // generation of successive output buffers. SelectivityVector filterInputRows_; VectorPtr filterOutput_; DecodedVector decodedFilterResult_; - // Join metadata and state. - std::shared_ptr joinNode_; - - ProbeOperatorState state_{ProbeOperatorState::kWaitForBuild}; - ContinueFuture future_{ContinueFuture::makeEmpty()}; + /////////////// + // PROBE STATE + // Variables used to track the probe-side state state during exection. + // These will change throughout setup and execution. - // Probe side state. + // Count of probe batches added (1-indexed). Primarily for debugging. + size_t probeCount_{0}; // Probe row being currently processed (related to `input_`). vector_size_t probeRow_{0}; - // Whether the current probeRow_ has produces a match. Used for left and full - // outer joins. - bool probeRowHasMatch_{false}; - - // Indicate if the probe side has empty input or not. For the last probe, - // this indicates if all the probe sides are empty or not. This flag is used - // for mismatched output producing. - bool probeSideEmpty_{true}; + // Whether the current probeRow_ has found a match. Needed for left join. + bool probeHasMatch_{false}; - // Build side state. - - // Stores the data for build vectors (right side of the join). - std::optional> buildVectors_; + /////////////// + // BUILD STATE + // Variables used to track the build-side state state during exection. + // These will change throughout setup and execution. // Index into `buildVectors_` for the build vector being currently processed. size_t buildIndex_{0}; // Row being currently processed from `buildVectors_[buildIndex_]`. vector_size_t buildRow_{0}; - - // Stores the ranges of build values to be copied to the output vector (we - // batch them and copy once, instead of copying them row-by-row). - std::vector buildCopyRanges_; - - // List of output projections from the build side. 
Note that the list of - // projections from the probe side is available at `identityProjections_`. - std::vector buildProjections_; - - // Projections needed as input to the filter to evaluation spatial join filter - // conditions. Note that if this is a cross-join, filter projections are the - // same as output projections. - std::vector filterProbeProjections_; - std::vector filterBuildProjections_; - - BufferPtr buildOutMapping_; }; } // namespace facebook::velox::exec diff --git a/velox/exec/Spill.cpp b/velox/exec/Spill.cpp index 376d06502a3c..55acc12fa0c1 100644 --- a/velox/exec/Spill.cpp +++ b/velox/exec/Spill.cpp @@ -436,6 +436,8 @@ std::unique_ptr ConcatFilesSpillBatchStream::create( } bool ConcatFilesSpillBatchStream::nextBatch(RowVectorPtr& batch) { + TestValue::adjust( + "facebook::velox::exec::ConcatFilesSpillBatchStream::nextBatch", nullptr); VELOX_CHECK_NULL(batch); VELOX_CHECK(!atEnd_); for (; fileIndex_ < spillFiles_.size(); ++fileIndex_) { diff --git a/velox/exec/Spill.h b/velox/exec/Spill.h index 18fd2cbfcb17..582ca9f331bc 100644 --- a/velox/exec/Spill.h +++ b/velox/exec/Spill.h @@ -21,11 +21,11 @@ #include #include "velox/common/base/SpillConfig.h" #include "velox/common/base/SpillStats.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/common/compression/Compression.h" #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" #include "velox/exec/SpillFile.h" -#include "velox/exec/TreeOfLosers.h" #include "velox/exec/UnorderedStreamReader.h" #include "velox/exec/VectorHasher.h" #include "velox/vector/ComplexVector.h" diff --git a/velox/exec/SpillFile.cpp b/velox/exec/SpillFile.cpp index a5ad7d53d7e9..b04af2fc8133 100644 --- a/velox/exec/SpillFile.cpp +++ b/velox/exec/SpillFile.cpp @@ -16,8 +16,7 @@ #include "velox/exec/SpillFile.h" #include "velox/common/base/RuntimeMetrics.h" -#include "velox/common/file/FileSystems.h" -#include "velox/vector/VectorStream.h" +#include "velox/serializers/SerializedPageFile.h" 
namespace facebook::velox::exec { namespace { @@ -29,49 +28,6 @@ namespace { static const bool kDefaultUseLosslessTimestamp = true; } // namespace -std::unique_ptr SpillWriteFile::create( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig) { - return std::unique_ptr( - new SpillWriteFile(id, pathPrefix, fileCreateConfig)); -} - -SpillWriteFile::SpillWriteFile( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig) - : id_(id), path_(fmt::format("{}-{}", pathPrefix, ordinalCounter_++)) { - auto fs = filesystems::getFileSystem(path_, nullptr); - file_ = fs->openFileForWrite( - path_, - filesystems::FileOptions{ - {{filesystems::FileOptions::kFileCreateConfig.toString(), - fileCreateConfig}}, - nullptr, - std::nullopt}); -} - -void SpillWriteFile::finish() { - VELOX_CHECK_NOT_NULL(file_); - size_ = file_->size(); - file_->close(); - file_ = nullptr; -} - -uint64_t SpillWriteFile::size() const { - if (file_ != nullptr) { - return file_->size(); - } - return size_; -} - -uint64_t SpillWriteFile::write(std::unique_ptr iobuf) { - auto writtenBytes = iobuf->computeChainDataLength(); - file_->append(std::move(iobuf)); - return writtenBytes; -} - SpillWriter::SpillWriter( const RowTypePtr& type, const std::vector& sortingKeys, @@ -83,108 +39,23 @@ SpillWriter::SpillWriter( common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, memory::MemoryPool* pool, folly::Synchronized* stats) - : type_(type), + : serializer::SerializedPageFileWriter( + pathPrefix, + targetFileSize, + writeBufferSize, + fileCreateConfig, + std::make_unique< + serializer::presto::PrestoVectorSerde::PrestoOptions>( + kDefaultUseLosslessTimestamp, + compressionKind, + 0.8, + /*_nullsFirst=*/true), + getNamedVectorSerde(VectorSerde::Kind::kPresto), + pool), + type_(type), sortingKeys_(sortingKeys), - compressionKind_(compressionKind), - pathPrefix_(pathPrefix), - targetFileSize_(targetFileSize), - writeBufferSize_(writeBufferSize), - 
fileCreateConfig_(fileCreateConfig), - updateAndCheckSpillLimitCb_(updateAndCheckSpillLimitCb), - pool_(pool), - serde_(getNamedVectorSerde(VectorSerde::Kind::kPresto)), - stats_(stats) {} - -SpillWriteFile* SpillWriter::ensureFile() { - if ((currentFile_ != nullptr) && (currentFile_->size() > targetFileSize_)) { - closeFile(); - } - if (currentFile_ == nullptr) { - currentFile_ = SpillWriteFile::create( - nextFileId_++, - fmt::format("{}-{}", pathPrefix_, finishedFiles_.size()), - fileCreateConfig_); - } - return currentFile_.get(); -} - -void SpillWriter::closeFile() { - if (currentFile_ == nullptr) { - return; - } - currentFile_->finish(); - updateSpilledFileStats(currentFile_->size()); - finishedFiles_.push_back(SpillFileInfo{ - .id = currentFile_->id(), - .type = type_, - .path = currentFile_->path(), - .size = currentFile_->size(), - .sortingKeys = sortingKeys_, - .compressionKind = compressionKind_}); - currentFile_.reset(); -} - -size_t SpillWriter::numFinishedFiles() const { - return finishedFiles_.size(); -} - -uint64_t SpillWriter::flush() { - if (batch_ == nullptr) { - return 0; - } - - auto* file = ensureFile(); - VELOX_CHECK_NOT_NULL(file); - - IOBufOutputStream out( - *pool_, nullptr, std::max(64 * 1024, batch_->size())); - uint64_t flushTimeNs{0}; - { - NanosecondTimer timer(&flushTimeNs); - batch_->flush(&out); - } - batch_.reset(); - - uint64_t writeTimeNs{0}; - uint64_t writtenBytes{0}; - auto iobuf = out.getIOBuf(); - { - NanosecondTimer timer(&writeTimeNs); - writtenBytes = file->write(std::move(iobuf)); - } - updateWriteStats(writtenBytes, flushTimeNs, writeTimeNs); - updateAndCheckSpillLimitCb_(writtenBytes); - return writtenBytes; -} - -uint64_t SpillWriter::write( - const RowVectorPtr& rows, - const folly::Range& indices) { - checkNotFinished(); - - uint64_t timeNs{0}; - { - NanosecondTimer timer(&timeNs); - if (batch_ == nullptr) { - serializer::presto::PrestoVectorSerde::PrestoOptions options = { - kDefaultUseLosslessTimestamp, - 
compressionKind_, - 0.8, - /*_nullsFirst=*/true}; - batch_ = std::make_unique(pool_, serde_); - batch_->createStreamTree( - std::static_pointer_cast(rows->type()), - 1'000, - &options); - } - batch_->append(rows, indices); - } - updateAppendStats(rows->size(), timeNs); - if (batch_->size() < writeBufferSize_) { - return 0; - } - return flush(); -} + stats_(stats), + updateAndCheckLimitCb_(updateAndCheckSpillLimitCb) {} void SpillWriter::updateAppendStats( uint64_t numRows, @@ -206,34 +77,40 @@ void SpillWriter::updateWriteStats( ++statsLocked->spillWrites; common::updateGlobalSpillWriteStats( spilledBytes, flushTimeNs, fileWriteTimeNs); + updateAndCheckLimitCb_(spilledBytes); } -void SpillWriter::updateSpilledFileStats(uint64_t fileSize) { +void SpillWriter::updateFileStats( + const serializer::SerializedPageFile::FileInfo& file) { ++stats_->wlock()->spilledFiles; addThreadLocalRuntimeStat( - "spillFileSize", RuntimeCounter(fileSize, RuntimeCounter::Unit::kBytes)); + "spillFileSize", RuntimeCounter(file.size, RuntimeCounter::Unit::kBytes)); common::incrementGlobalSpilledFiles(); } -void SpillWriter::finishFile() { - checkNotFinished(); - flush(); - closeFile(); - VELOX_CHECK_NULL(currentFile_); -} - SpillFiles SpillWriter::finish() { - checkNotFinished(); - auto finishGuard = folly::makeGuard([this]() { finished_ = true; }); - - finishFile(); - return std::move(finishedFiles_); + const auto serializedPageFiles = + serializer::SerializedPageFileWriter::finish(); + SpillFiles spillFiles; + spillFiles.reserve(serializedPageFiles.size()); + for (const auto& fileInfo : serializedPageFiles) { + spillFiles.push_back(SpillFileInfo{ + .id = fileInfo.id, + .type = type_, + .path = fileInfo.path, + .size = fileInfo.size, + .sortingKeys = sortingKeys_, + .compressionKind = serdeOptions_->compressionKind}); + } + return spillFiles; } std::vector SpillWriter::testingSpilledFilePaths() const { checkNotFinished(); std::vector spilledFilePaths; + spilledFilePaths.reserve( + 
finishedFiles_.size() + (currentFile_ != nullptr ? 1 : 0)); for (auto& file : finishedFiles_) { spilledFilePaths.push_back(file.path); } @@ -283,44 +160,25 @@ SpillReadFile::SpillReadFile( common::CompressionKind compressionKind, memory::MemoryPool* pool, folly::Synchronized* stats) - : id_(id), + : serializer::SerializedPageFileReader( + path, + bufferSize, + type, + getNamedVectorSerde(VectorSerde::Kind::kPresto), + std::make_unique< + serializer::presto::PrestoVectorSerde::PrestoOptions>( + kDefaultUseLosslessTimestamp, + compressionKind, + 0.8, + /*_nullsFirst=*/true), + pool), + id_(id), path_(path), size_(size), - type_(type), sortingKeys_(sortingKeys), - compressionKind_(compressionKind), - readOptions_{ - kDefaultUseLosslessTimestamp, - compressionKind_, - 0.8, - /*_nullsFirst=*/true}, - pool_(pool), - serde_(getNamedVectorSerde(VectorSerde::Kind::kPresto)), - stats_(stats) { - auto fs = filesystems::getFileSystem(path_, nullptr); - auto file = fs->openFileForRead(path_); - input_ = std::make_unique( - std::move(file), bufferSize, pool_); -} - -bool SpillReadFile::nextBatch(RowVectorPtr& rowVector) { - if (input_->atEnd()) { - recordSpillStats(); - return false; - } - - uint64_t timeNs{0}; - { - NanosecondTimer timer{&timeNs}; - VectorStreamGroup::read( - input_.get(), pool_, type_, serde_, &rowVector, &readOptions_); - } - stats_->wlock()->spillDeserializationTimeNanos += timeNs; - common::updateGlobalSpillDeserializationTimeNs(timeNs); - return true; -} + stats_(stats) {} -void SpillReadFile::recordSpillStats() { +void SpillReadFile::updateFinalStats() { VELOX_CHECK(input_->atEnd()); const auto readStats = input_->stats(); common::updateGlobalSpillReadStats( @@ -329,5 +187,11 @@ void SpillReadFile::recordSpillStats() { lockedSpillStats->spillReads += readStats.numReads; lockedSpillStats->spillReadTimeNanos += readStats.readTimeNs; lockedSpillStats->spillReadBytes += readStats.readBytes; -} +}; + +void SpillReadFile::updateSerializationTimeStats(uint64_t 
timeNs) { + stats_->wlock()->spillDeserializationTimeNanos += timeNs; + common::updateGlobalSpillDeserializationTimeNs(timeNs); +}; + } // namespace facebook::velox::exec diff --git a/velox/exec/SpillFile.h b/velox/exec/SpillFile.h index 55547eec0a53..e959b86840f1 100644 --- a/velox/exec/SpillFile.h +++ b/velox/exec/SpillFile.h @@ -17,14 +17,16 @@ #pragma once #include +#include #include "velox/common/base/SpillConfig.h" #include "velox/common/base/SpillStats.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/common/compression/Compression.h" #include "velox/common/file/File.h" #include "velox/common/file/FileInputStream.h" -#include "velox/exec/TreeOfLosers.h" #include "velox/serializers/PrestoSerializer.h" +#include "velox/serializers/SerializedPageFile.h" #include "velox/vector/ComplexVector.h" #include "velox/vector/DecodedVector.h" #include "velox/vector/VectorStream.h" @@ -32,53 +34,6 @@ namespace facebook::velox::exec { using SpillSortKey = std::pair; -/// Represents a spill file for writing the serialized spilled data into a disk -/// file. -class SpillWriteFile { - public: - static std::unique_ptr create( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig); - - uint32_t id() const { - return id_; - } - - /// Returns the file size in bytes. - uint64_t size() const; - - const std::string& path() const { - return path_; - } - - uint64_t write(std::unique_ptr iobuf); - - WriteFile* file() { - return file_.get(); - } - - /// Finishes writing and flushes any unwritten data. - void finish(); - - private: - static inline std::atomic ordinalCounter_{0}; - - SpillWriteFile( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig); - - // The spill file id which is monotonically increasing and unique for each - // associated spill partition. - const uint32_t id_; - const std::string path_; - - std::unique_ptr file_; - // Byte size of the backing file. Set when finishing writing. 
- uint64_t size_{0}; -}; - /// Records info of a finished spill file which is used for read. struct SpillFileInfo { uint32_t id; @@ -95,7 +50,7 @@ using SpillFiles = std::vector; /// Used to write the spilled data to a sequence of files for one partition. If /// data is sorted, each file is sorted. The globally sorted order is produced /// by merging the constituent files. -class SpillWriter { +class SpillWriter : public serializer::SerializedPageFileWriter { public: /// 'type' is a RowType describing the content. 'numSortKeys' is the number /// of leading columns on which the data is sorted. 'path' is a file path @@ -120,25 +75,9 @@ class SpillWriter { memory::MemoryPool* pool, folly::Synchronized* stats); - /// Adds 'rows' for the positions in 'indices' into 'this'. The indices - /// must produce a view where the rows are sorted if sorting is desired. - /// Consecutive calls must have sorted data so that the first row of the - /// next call is not less than the last row of the previous call. - /// Returns the size to write. - uint64_t write( - const RowVectorPtr& rows, - const folly::Range& indices); - - /// Closes the current output file if any. Subsequent calls to write will - /// start a new one. - void finishFile(); - - /// Returns the number of current finished files. - size_t numFinishedFiles() const; - /// Finishes this file writer and returns the written spill files info. /// - /// NOTE: we don't allow write to a spill writer after t + /// NOTE: we don't allow write to a spill writer after finish SpillFiles finish(); std::vector testingSpilledFilePaths() const; @@ -146,55 +85,29 @@ class SpillWriter { std::vector testingSpilledFileIds() const; private: - FOLLY_ALWAYS_INLINE void checkNotFinished() const { - VELOX_CHECK(!finished_, "SpillWriter has finished"); - } - - // Returns an open spill file for write. If there is no open spill file, then - // the function creates a new one. 
If the current open spill file exceeds the - // target file size limit, then it first closes the current one and then - // creates a new one. 'currentFile_' points to the current open spill file. - SpillWriteFile* ensureFile(); - - // Closes the current open spill file pointed by 'currentFile_'. - void closeFile(); - - // Writes data from 'batch_' to the current output file. Returns the actual - // written size. - uint64_t flush(); - // Invoked to increment the number of spilled files and the file size. - void updateSpilledFileStats(uint64_t fileSize); + void updateFileStats( + const serializer::SerializedPageFile::FileInfo& fileInfo) override; // Invoked to update the number of spilled rows. - void updateAppendStats(uint64_t numRows, uint64_t serializationTimeUs); + void updateAppendStats(uint64_t numRows, uint64_t serializationTimeUs) + override; // Invoked to update the disk write stats. void updateWriteStats( uint64_t spilledBytes, uint64_t flushTimeUs, - uint64_t writeTimeUs); + uint64_t writeTimeUs) override; const RowTypePtr type_; + const std::vector sortingKeys_; - const common::CompressionKind compressionKind_; - const std::string pathPrefix_; - const uint64_t targetFileSize_; - const uint64_t writeBufferSize_; - const std::string fileCreateConfig_; - // Updates the aggregated spill bytes of this query, and throws if exceeds - // the max spill bytes limit. - const common::UpdateAndCheckSpillLimitCB updateAndCheckSpillLimitCb_; - memory::MemoryPool* const pool_; - VectorSerde* const serde_; folly::Synchronized* const stats_; - bool finished_{false}; - uint32_t nextFileId_{0}; - std::unique_ptr batch_; - std::unique_ptr currentFile_; - SpillFiles finishedFiles_; + // Updates the aggregated bytes of this query, and throws if exceeds + // the max bytes limit. 
+ const common::UpdateAndCheckSpillLimitCB updateAndCheckLimitCb_; }; /// Represents a spill file for read which turns the serialized spilled data @@ -204,7 +117,7 @@ class SpillWriter { /// needs to remove the unused spill files at some point later. For example, a /// query Task deletes all the generated spill files in one operation using /// rmdir() call. -class SpillReadFile { +class SpillReadFile : public serializer::SerializedPageFileReader { public: static std::unique_ptr create( const SpillFileInfo& fileInfo, @@ -220,8 +133,6 @@ class SpillReadFile { return sortingKeys_; } - bool nextBatch(RowVectorPtr& rowVector); - /// Returns the file size in bytes. uint64_t size() const { return size_; @@ -243,24 +154,23 @@ class SpillReadFile { memory::MemoryPool* pool, folly::Synchronized* stats); - // Invoked to record spill read stats at the end of read input. - void recordSpillStats(); + // Records spill read stats at the end of read input. + void updateFinalStats() override; + + void updateSerializationTimeStats(uint64_t timeNs) override; // The spill file id which is monotonically increasing and unique for each // associated spill partition. const uint32_t id_; + const std::string path_; + // The file size in bytes. const uint64_t size_; - // The data type of spilled data. 
- const RowTypePtr type_; + const std::vector sortingKeys_; - const common::CompressionKind compressionKind_; - const serializer::presto::PrestoVectorSerde::PrestoOptions readOptions_; - memory::MemoryPool* const pool_; - VectorSerde* const serde_; - folly::Synchronized* const stats_; - std::unique_ptr input_; + folly::Synchronized* const stats_; }; + } // namespace facebook::velox::exec diff --git a/velox/exec/TableScan.cpp b/velox/exec/TableScan.cpp index 9157e876d5f3..6c19c1b89510 100644 --- a/velox/exec/TableScan.cpp +++ b/velox/exec/TableScan.cpp @@ -180,15 +180,16 @@ RowVectorPtr TableScan::getOutput() { } continue; } - const auto estimatedRowSize = dataSource_->estimatedRowSize(); - readBatchSize_ = - estimatedRowSize == connector::DataSource::kUnknownRowSize - ? outputBatchRows() - : outputBatchRows(estimatedRowSize); } VELOX_CHECK(!needNewSplit_); VELOX_CHECK(!hasDrained()); + const auto estimatedRowSize = dataSource_->estimatedRowSize(); + // TODO: Expose this to operator stats. + VLOG(1) << "estimatedRowSize = " << estimatedRowSize; + readBatchSize_ = estimatedRowSize == connector::DataSource::kUnknownRowSize + ? 
outputBatchRows() + : outputBatchRows(estimatedRowSize); int32_t readBatchSize = readBatchSize_; if (maxFilteringRatio_ > 0) { readBatchSize = std::min( @@ -303,15 +304,15 @@ bool TableScan::getSplit() { if (!split.hasConnectorSplit()) { noMoreSplits_ = true; if (dataSource_) { - const auto connectorStats = dataSource_->runtimeStats(); + const auto connectorStats = dataSource_->getRuntimeStats(); auto lockedStats = stats_.wlock(); - for (const auto& [name, counter] : connectorStats) { + for (const auto& [name, metric] : connectorStats) { if (FOLLY_UNLIKELY(lockedStats->runtimeStats.count(name) == 0)) { - lockedStats->runtimeStats.emplace(name, RuntimeMetric(counter.unit)); + lockedStats->runtimeStats.emplace(name, RuntimeMetric(metric.unit)); } else { - VELOX_CHECK_EQ(lockedStats->runtimeStats.at(name).unit, counter.unit); + VELOX_CHECK_EQ(lockedStats->runtimeStats.at(name).unit, metric.unit); } - lockedStats->runtimeStats.at(name).addValue(counter.value); + lockedStats->runtimeStats.at(name).merge(metric); } } return false; diff --git a/velox/exec/Task.cpp b/velox/exec/Task.cpp index eab1bf7d59c6..437bdcf05875 100644 --- a/velox/exec/Task.cpp +++ b/velox/exec/Task.cpp @@ -320,8 +320,8 @@ std::shared_ptr Task::create( ExecutionMode mode, Consumer consumer, int32_t memoryArbitrationPriority, + std::optional spillDiskOpts, std::function onError) { - VELOX_CHECK_NOT_NULL(planFragment.planNode); return Task::create( taskId, std::move(planFragment), @@ -331,6 +331,7 @@ std::shared_ptr Task::create( (consumer ? 
[c = std::move(consumer)]() { return c; } : ConsumerSupplier{}), memoryArbitrationPriority, + std::move(spillDiskOpts), std::move(onError)); } @@ -343,6 +344,7 @@ std::shared_ptr Task::create( ExecutionMode mode, ConsumerSupplier consumerSupplier, int32_t memoryArbitrationPriority, + std::optional spillDiskOpts, std::function onError) { VELOX_CHECK_NOT_NULL(planFragment.planNode); auto task = std::shared_ptr(new Task( @@ -354,7 +356,7 @@ std::shared_ptr Task::create( std::move(consumerSupplier), memoryArbitrationPriority, std::move(onError))); - task->initTaskPool(); + task->init(std::move(spillDiskOpts)); task->addToTaskList(); return task; } @@ -469,6 +471,72 @@ void Task::ensureBarrierSupport() const { firstNodeNotSupportingBarrier_->name()); } +void Task::init(std::optional&& spillDiskOpts) { + VELOX_CHECK(driverFactories_.empty()); + initTaskPool(); + + setSpillDiskConfig(std::move(spillDiskOpts)); + + if (mode_ != Task::ExecutionMode::kSerial) { + return; + } + + // Create drivers. + VELOX_CHECK_NULL( + consumerSupplier_, + "Serial execution mode doesn't support delivering results to a " + "callback"); + + taskStats_.executionStartTimeMs = getCurrentTimeMs(); + LocalPlanner::plan( + planFragment_, nullptr, &driverFactories_, queryCtx_->queryConfig(), 1); + exchangeClients_.resize(driverFactories_.size()); + + // In Task::next() we always assume ungrouped execution. + for (const auto& factory : driverFactories_) { + VELOX_CHECK(factory->supportsSerialExecution()); + numDriversUngrouped_ += factory->numDrivers; + numTotalDrivers_ += factory->numTotalDrivers; + taskStats_.pipelineStats.emplace_back( + factory->inputDriver, factory->outputDriver); + } + + // Create drivers. 
+ createSplitGroupStateLocked(kUngroupedGroupId); + std::vector> drivers = + createDriversLocked(kUngroupedGroupId); + if (pool_->reservedBytes() != 0) { + VELOX_FAIL( + "Unexpected memory pool allocations during task[{}] driver initialization: {}", + taskId_, + pool_->treeMemoryUsage()); + } + + drivers_ = std::move(drivers); + driverBlockingStates_.reserve(drivers_.size()); + for (auto i = 0; i < drivers_.size(); ++i) { + driverBlockingStates_.emplace_back( + std::make_unique(drivers_[i].get())); + } +} + +void Task::setSpillDiskConfig( + std::optional&& spillDiskOpts) { + if (!spillDiskOpts.has_value()) { + return; + } + VELOX_CHECK( + !spillDiskOpts->spillDirPath.empty(), "Spill directory can't be empty"); + VELOX_CHECK( + spillDiskOpts->spillDirCreated || spillDiskOpts->spillDirCreateCb); + VELOX_CHECK_NULL(spillDirectoryCallback_); + VELOX_CHECK(!spillDirectoryCreated_); + VELOX_CHECK(spillDirectory_.empty()); + spillDirectory_ = std::move(spillDiskOpts->spillDirPath); + spillDirectoryCreated_ = spillDiskOpts->spillDirCreated; + spillDirectoryCallback_ = std::move(spillDiskOpts->spillDirCreateCb); +} + Task::TaskList& Task::taskList() { static TaskList taskList; return taskList; @@ -547,7 +615,7 @@ bool Task::allNodesReceivedNoMoreSplitsMessageLocked() const { const std::string& Task::getOrCreateSpillDirectory() { VELOX_CHECK( !spillDirectory_.empty() || spillDirectoryCallback_, - "Spill directory or spill directory callback must be set "); + "Spill directory or spill directory callback must be set"); if (spillDirectoryCreated_) { return spillDirectory_; } @@ -769,48 +837,8 @@ RowVectorPtr Task::next(ContinueFuture* future) { } } - // On first call, create the drivers. 
- if (driverFactories_.empty()) { - VELOX_CHECK_NULL( - consumerSupplier_, - "Serial execution mode doesn't support delivering results to a " - "callback"); - - taskStats_.executionStartTimeMs = getCurrentTimeMs(); - LocalPlanner::plan( - planFragment_, nullptr, &driverFactories_, queryCtx_->queryConfig(), 1); - exchangeClients_.resize(driverFactories_.size()); - - // In Task::next() we always assume ungrouped execution. - for (const auto& factory : driverFactories_) { - VELOX_CHECK(factory->supportsSerialExecution()); - numDriversUngrouped_ += factory->numDrivers; - numTotalDrivers_ += factory->numTotalDrivers; - taskStats_.pipelineStats.emplace_back( - factory->inputDriver, factory->outputDriver); - } - - // Create drivers. - createSplitGroupStateLocked(kUngroupedGroupId); - std::vector> drivers = - createDriversLocked(kUngroupedGroupId); - if (pool_->reservedBytes() != 0) { - VELOX_FAIL( - "Unexpected memory pool allocations during task[{}] driver initialization: {}", - taskId_, - pool_->treeMemoryUsage()); - } - - drivers_ = std::move(drivers); - driverBlockingStates_.reserve(drivers_.size()); - for (auto i = 0; i < drivers_.size(); ++i) { - driverBlockingStates_.emplace_back( - std::make_unique(drivers_[i].get())); - } - if (underBarrier()) { - startDriverBarriersLocked(); - } - } + VELOX_CHECK_EQ( + state_, TaskState::kRunning, "Task has already finished processing."); // Run drivers one at a time. If a driver blocks, continue running the other // drivers. 
Running other drivers is expected to unblock some or all blocked diff --git a/velox/exec/Task.h b/velox/exec/Task.h index 01b906df087f..33876721cc28 100644 --- a/velox/exec/Task.h +++ b/velox/exec/Task.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "velox/common/base/SkewedPartitionBalancer.h" #include "velox/common/base/TraceConfig.h" #include "velox/core/PlanFragment.h" @@ -66,11 +68,16 @@ class Task : public std::enable_shared_from_this { /// @param consumer Optional factory function to get callbacks to pass the /// results of the execution. In a parallel execution mode, results from each /// thread are passed on to a separate consumer. + /// @param memoryArbitrationPriority Priority used by the memory arbitrator + /// to determine which task should have its memory reclaimed first when the + /// system is under memory pressure. Higher values indicate higher priority + /// (lower likelihood of being reclaimed). Default is 0. + /// @param spillDiskOpts Optional configuration for spill disk storage. When + /// provided, allows operators to spill intermediate data to disk during + /// execution when memory pressure is high. Includes spill directory path + /// and callback options. Default is std::nullopt (no spilling). /// @param onError Optional callback to receive an exception if task /// execution fails. - /// @param memoryArbitrationPriority Optional priority on task that, in a - /// multi task system, is used for memory arbitration to decide the order of - /// reclaiming. 
static std::shared_ptr create( const std::string& taskId, core::PlanFragment planFragment, @@ -79,6 +86,7 @@ class Task : public std::enable_shared_from_this { ExecutionMode mode, Consumer consumer = nullptr, int32_t memoryArbitrationPriority = 0, + std::optional spillDiskOpts = std::nullopt, std::function onError = nullptr); static std::shared_ptr create( @@ -89,6 +97,7 @@ class Task : public std::enable_shared_from_this { ExecutionMode mode, ConsumerSupplier consumerSupplier, int32_t memoryArbitrationPriority = 0, + std::optional spillDiskOpts = std::nullopt, std::function onError = nullptr); /// Convenience function for shortening a Presto taskId. To be used @@ -97,22 +106,6 @@ class Task : public std::enable_shared_from_this { ~Task(); - /// Specify directory to which data will be spilled if spilling is enabled and - /// required. Set 'alreadyCreated' to true if the directory has already been - /// created by the caller. - void setSpillDirectory( - const std::string& spillDirectory, - bool alreadyCreated = true) { - spillDirectory_ = spillDirectory; - spillDirectoryCreated_ = alreadyCreated; - } - - void setCreateSpillDirectoryCb( - std::function spillDirectoryCallback) { - VELOX_CHECK_NULL(spillDirectoryCallback_); - spillDirectoryCallback_ = std::move(spillDirectoryCallback); - } - /// Returns human-friendly representation of the plan augmented with runtime /// statistics. The implementation invokes exec::printPlanWithStats(). /// @@ -821,6 +814,13 @@ class Task : public std::enable_shared_from_this { int32_t memoryArbitrationPriority = 0, std::function onError = nullptr); + // Invoked to do post-create initialization. + void init(std::optional&& spillDiskOpts); + + // Invoked to initialize the spill storage config for this task. + void setSpillDiskConfig( + std::optional&& spillDiskOpts); + // Invoked to add this to the system-wide running task list on task creation. 
void addToTaskList(); @@ -1377,7 +1377,7 @@ class Task : public std::enable_shared_from_this { // The promises for the futures returned to callers of requestBarrier(). std::vector barrierFinishPromises_; - std::atomic toYield_ = 0; + std::atomic_int32_t toYield_ = 0; int32_t numThreads_ = 0; // Microsecond real time when 'this' last went from no threads to // one thread running. Used to decide if continuous run should be diff --git a/velox/exec/TaskTraceReader.cpp b/velox/exec/TaskTraceReader.cpp index 59647bd5305d..6d05b66bda9a 100644 --- a/velox/exec/TaskTraceReader.cpp +++ b/velox/exec/TaskTraceReader.cpp @@ -66,17 +66,17 @@ core::PlanNodePtr TaskTraceMetadataReader::queryPlan() const { } std::string TaskTraceMetadataReader::nodeName(const std::string& nodeId) const { - const auto* traceNode = core::PlanNode::findFirstNode( - tracePlanNode_.get(), - [&nodeId](const core::PlanNode* node) { return node->id() == nodeId; }); + LOG(ERROR) << "node id " << nodeId << " trace plan node " + << tracePlanNode_->toString(true, true); + const auto* traceNode = + core::PlanNode::findNodeById(tracePlanNode_.get(), nodeId); return std::string(traceNode->name()); } std::optional TaskTraceMetadataReader::connectorId( const std::string& nodeId) const { - const auto* traceNode = core::PlanNode::findFirstNode( - tracePlanNode_.get(), - [&nodeId](const core::PlanNode* node) { return node->id() == nodeId; }); + const auto* traceNode = + core::PlanNode::findNodeById(tracePlanNode_.get(), nodeId); if (const auto* indexLookupJoinNode = dynamic_cast(traceNode)) { @@ -84,14 +84,16 @@ std::optional TaskTraceMetadataReader::connectorId( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(); VELOX_CHECK(!indexLookupConnectorId.empty()); return indexLookupConnectorId; - } else if ( - const auto* tableScanNode = + } + + if (const auto* tableScanNode = dynamic_cast(traceNode)) { VELOX_CHECK_NOT_NULL(tableScanNode); const auto connectorId = tableScanNode->tableHandle()->connectorId(); 
VELOX_CHECK(!connectorId.empty()); return connectorId; } + return std::nullopt; } } // namespace facebook::velox::exec::trace diff --git a/velox/exec/TopNRowNumber.cpp b/velox/exec/TopNRowNumber.cpp index 5bd614ffa1ed..9f0dafb202b3 100644 --- a/velox/exec/TopNRowNumber.cpp +++ b/velox/exec/TopNRowNumber.cpp @@ -213,6 +213,10 @@ void TopNRowNumber::addInput(RowVectorPtr input) { // Otherwise, check if row should replace an existing row or be discarded. processInputRowLoop(numInput); + // It is determined that the TopNRowNumber (as a partial) is not rejecting + // enough input rows to make the duplicate detection worthwhile. Hence, + // abandon the processing at this partial TopN and let the final TopN do + // the processing. if (abandonPartialEarly()) { abandonedPartial_ = true; addRuntimeStat("abandonedPartial", RuntimeCounter(1)); @@ -356,9 +360,14 @@ void TopNRowNumber::updateEstimatedOutputRowSize() { } TopNRowNumber::TopRows* TopNRowNumber::nextPartition() { + auto setNextRank = [&](const TopRows& partition) { + nextRank_ = partition.topRank; + }; + if (!table_) { if (!outputPartitionNumber_) { outputPartitionNumber_ = 0; + setNextRank(*singlePartition_); return singlePartition_.get(); } return nullptr; @@ -384,7 +393,15 @@ TopNRowNumber::TopRows* TopNRowNumber::nextPartition() { } } - return &partitionAt(partitions_[outputPartitionNumber_.value()]); + auto partition = &partitionAt(partitions_[outputPartitionNumber_.value()]); + setNextRank(*partition); + return partition; +} + +void TopNRowNumber::computeNextRankInMemory( + const TopRows& /*partition*/, + vector_size_t /*rowIndex*/) { + nextRank_ -= 1; } void TopNRowNumber::appendPartitionRows( @@ -394,14 +411,16 @@ void TopNRowNumber::appendPartitionRows( FlatVector* rowNumbers) { // The partition.rows priority queue pops rows in order of reverse // row numbers. 
- auto rowNumber = partition.rows.size(); for (auto i = 0; i < numRows; ++i) { const auto index = outputOffset + i; if (rowNumbers) { - rowNumbers->set(index, rowNumber--); + rowNumbers->set(index, nextRank_); } outputRows_[index] = partition.rows.top(); partition.rows.pop(); + if (!partition.rows.empty()) { + computeNextRankInMemory(partition, index); + } } } @@ -417,7 +436,7 @@ RowVectorPtr TopNRowNumber::getOutput() { return output; } - // We may have input accumulated in 'data_'. + // There could be older rows accumulated in 'data_'. if (data_->numRows() > 0) { return getOutputFromMemory(); } @@ -426,6 +445,7 @@ RowVectorPtr TopNRowNumber::getOutput() { finished_ = true; } + // There is no data to return at this moment. return nullptr; } @@ -433,6 +453,8 @@ RowVectorPtr TopNRowNumber::getOutput() { return nullptr; } + // All the input data is received, so the operator can start producing + // output. RowVectorPtr output; if (merge_ != nullptr) { output = getOutputFromSpill(); @@ -512,13 +534,15 @@ RowVectorPtr TopNRowNumber::getOutputFromMemory() { return output; } -bool TopNRowNumber::isNewPartition( +bool TopNRowNumber::compareSpillRowColumns( const RowVectorPtr& output, vector_size_t index, - SpillMergeStream* next) { + const SpillMergeStream* next, + vector_size_t startColumn, + vector_size_t endColumn) { VELOX_CHECK_GT(index, 0); - for (auto i = 0; i < numPartitionKeys_; ++i) { + for (auto i = startColumn; i < endColumn; ++i) { if (!output->childAt(inputChannels_[i]) ->equalValueAt( next->current().childAt(i).get(), @@ -530,22 +554,38 @@ bool TopNRowNumber::isNewPartition( return false; } -void TopNRowNumber::setupNextOutput( +// Compares the partition keys for new partitions. 
+bool TopNRowNumber::isNewPartition( const RowVectorPtr& output, - int32_t rowNumber) { - auto* lookAhead = merge_->next(); - if (lookAhead == nullptr) { - nextRowNumber_ = 0; + vector_size_t index, + const SpillMergeStream* next) { + return compareSpillRowColumns(output, index, next, 0, numPartitionKeys_); +} + +void TopNRowNumber::computeNextRankFromSpill( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next) { + if (isNewPartition(output, index, next)) { + nextRank_ = 1; return; } - if (isNewPartition(output, output->size(), lookAhead)) { - nextRowNumber_ = 0; + nextRank_ += 1; + return; +} + +void TopNRowNumber::setupNextOutput(const RowVectorPtr& output) { + auto resetNextRank = [this]() { nextRank_ = 1; }; + + auto* lookAhead = merge_->next(); + if (lookAhead == nullptr) { + resetNextRank(); return; } - nextRowNumber_ = rowNumber; - if (nextRowNumber_ < limit_) { + computeNextRankFromSpill(output, output->size(), lookAhead); + if (nextRank_ <= limit_) { return; } @@ -553,14 +593,14 @@ void TopNRowNumber::setupNextOutput( lookAhead->pop(); while (auto* next = merge_->next()) { if (isNewPartition(output, output->size(), next)) { - nextRowNumber_ = 0; + resetNextRank(); return; } next->pop(); } // This partition is the last partition. - nextRowNumber_ = 0; + resetNextRank(); } RowVectorPtr TopNRowNumber::getOutputFromSpill() { @@ -570,7 +610,7 @@ RowVectorPtr TopNRowNumber::getOutputFromSpill() { // All rows from the same partition will appear together. // We'll identify partition boundaries by comparing partition keys of the // current row with the previous row. When new partition starts, we'll reset - // row number to zero. Once row number reaches the 'limit_', we'll start + // nextRank_ to zero. Once rank reaches the 'limit_', we'll start // dropping rows until the next partition starts. // We'll emit output every time we accumulate 'outputBatchSize_' rows. 
@@ -583,24 +623,20 @@ RowVectorPtr TopNRowNumber::getOutputFromSpill() { // Index of the next row to append to output. vector_size_t index = 0; - - // Row number of the next row in the current partition. - vector_size_t rowNumber = nextRowNumber_; - VELOX_CHECK_LT(rowNumber, limit_); + VELOX_CHECK_LE(nextRank_, limit_); for (;;) { auto next = merge_->next(); if (next == nullptr) { break; } - // Check if this row comes from a new partition. - if (index > 0 && isNewPartition(output, index, next)) { - rowNumber = 0; + if (index > 0) { + computeNextRankFromSpill(output, index, next); } // Copy this row to the output buffer if this partition has // < limit_ rows output. - if (rowNumber < limit_) { + if (nextRank_ <= limit_) { for (auto i = 0; i < inputChannels_.size(); ++i) { output->childAt(inputChannels_[i]) ->copy( @@ -611,10 +647,9 @@ RowVectorPtr TopNRowNumber::getOutputFromSpill() { } if (rowNumbers) { // Row numbers start with 1. - rowNumbers->set(index, rowNumber + 1); + rowNumbers->set(index, nextRank_); } ++index; - ++rowNumber; } // Pop this row from the spill. @@ -625,8 +660,8 @@ RowVectorPtr TopNRowNumber::getOutputFromSpill() { // Prepare the next batch : // i) If 'limit_' is reached for this partition, then skip the rows // until the next partition. - // ii) If the next row is from a new partition, then reset rowNumber_. - setupNextOutput(output, rowNumber); + // ii) If the next row is from a new partition, then reset nextRank_. + setupNextOutput(output); return output; } } diff --git a/velox/exec/TopNRowNumber.h b/velox/exec/TopNRowNumber.h index dc21f0f93c47..6df6a910cd76 100644 --- a/velox/exec/TopNRowNumber.h +++ b/velox/exec/TopNRowNumber.h @@ -120,6 +120,12 @@ class TopNRowNumber : public Operator { // partitions left. TopRows* nextPartition(); + // Computes the rank for the next row to be output + // (all output rows in memory). 
+ void computeNextRankInMemory( + const TopRows& partition, + vector_size_t rowIndex); + // Appends numRows of the output partition the output. Note: The rows are // popped in reverse order of the row_number. // NOTE: This function erases the yielded output rows from the partition @@ -150,17 +156,32 @@ class TopNRowNumber : public Operator { bool isNewPartition( const RowVectorPtr& output, vector_size_t index, - SpillMergeStream* next); + const SpillMergeStream* next); + + // Utility method to compare values from startColumn to endColumn for + // 'next' row from SpillMergeStream with current row of output (at index). + bool compareSpillRowColumns( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next, + vector_size_t startColumn, + vector_size_t endColumn); + + // Computes next rank value for spill output. + void computeNextRankFromSpill( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next); - // Sets nextRowNumber_ to rowNumber. Checks if next row in 'merge_' belongs to - // a different partition than last row in 'output' and if so updates - // nextRowNumber_ to 0. Also, checks current partition reached the limit on - // number of rows and if so advances 'merge_' to the first row on the next - // partition and sets nextRowNumber_ to 0. + // Checks if next row in 'merge_' belongs to a different partition than last + // row in 'output' and if so updates nextRank_ to 1. + // Also, checks current partition reached the limit on number of rows and + // if so advances 'merge_' to the first row on the next + // partition and sets nextRank_ to 1. // // @post 'merge_->next()' is either at end or points to a row that should be - // included in the next output batch using 'nextRowNumber_'. - void setupNextOutput(const RowVectorPtr& output, int32_t rowNumber); + // included in the next output batch using 'nextRank_'. + void setupNextOutput(const RowVectorPtr& output); // Called in noMoreInput() and spill(). 
void updateEstimatedOutputRowSize(); @@ -260,7 +281,8 @@ class TopNRowNumber : public Operator { // Used to sort-merge spilled data. std::unique_ptr> merge_; - // Row number for the first row in the next output batch from the spiller. - int32_t nextRowNumber_{0}; + // Row number (or rank or dense_rank in the future) for the next row being + // output from memory or the spiller. + vector_size_t nextRank_{1}; }; } // namespace facebook::velox::exec diff --git a/velox/exec/TraceUtil.cpp b/velox/exec/TraceUtil.cpp index f2f9c49d8aa1..a483bd0abf94 100644 --- a/velox/exec/TraceUtil.cpp +++ b/velox/exec/TraceUtil.cpp @@ -38,47 +38,6 @@ std::string findLastPathNode(const std::string& path) { return pathNodes.back(); } -const std::vector kEmptySources; - -class DummySourceNode final : public core::PlanNode { - public: - explicit DummySourceNode(RowTypePtr outputType) - : PlanNode(""), outputType_(std::move(outputType)) {} - - const RowTypePtr& outputType() const override { - return outputType_; - } - - const std::vector& sources() const override { - return kEmptySources; - } - - std::string_view name() const override { - return "DummySource"; - } - - folly::dynamic serialize() const override { - folly::dynamic obj = folly::dynamic::object; - obj["name"] = "DummySource"; - obj["outputType"] = outputType_->serialize(); - return obj; - } - - static core::PlanNodePtr create(const folly::dynamic& obj, void* context) { - return std::make_shared( - ISerializable::deserialize(obj["outputType"])); - } - - private: - void addDetails(std::stringstream& stream) const override { - // Nothing to add. 
- } - - const RowTypePtr outputType_; -}; - -void registerDummySourceSerDe(); - std::unordered_map& traceNodeRegistry() { static std::unordered_map registry; return registry; @@ -221,10 +180,8 @@ RowTypePtr getDataType( const core::PlanNodePtr& tracedPlan, const std::string& tracedNodeId, size_t sourceIndex) { - const auto* traceNode = core::PlanNode::findFirstNode( - tracedPlan.get(), [&tracedNodeId](const core::PlanNode* node) { - return node->id() == tracedNodeId; - }); + const auto* traceNode = + core::PlanNode::findNodeById(tracedPlan.get(), tracedNodeId); VELOX_CHECK_NOT_NULL( traceNode, "traced node id {} not found in the traced plan", @@ -306,9 +263,7 @@ bool canTrace(const std::string& operatorType) { core::PlanNodePtr getTraceNode( const core::PlanNodePtr& plan, core::PlanNodeId nodeId) { - const auto* traceNode = core::PlanNode::findFirstNode( - plan.get(), - [&nodeId](const core::PlanNode* node) { return node->id() == nodeId; }); + const auto* traceNode = core::PlanNode::findNodeById(plan.get(), nodeId); VELOX_CHECK_NOT_NULL(traceNode, "Failed to find node with id {}", nodeId); if (const auto* hashJoinNode = dynamic_cast(traceNode)) { @@ -407,7 +362,8 @@ core::PlanNodePtr getTraceNode( indexLookupJoinNode->leftKeys(), indexLookupJoinNode->rightKeys(), indexLookupJoinNode->joinConditions(), - indexLookupJoinNode->includeMatchColumn(), + indexLookupJoinNode->filter(), + indexLookupJoinNode->hasMarker(), std::make_shared( indexLookupJoinNode->sources().front()->outputType()), // Probe side indexLookupJoinNode->lookupSource(), // Index side @@ -446,7 +402,7 @@ core::PlanNodePtr getTraceNode( unnestNode->unnestVariables(), unnestNode->unnestNames(), unnestNode->ordinalityName(), - unnestNode->emptyUnnestValueName(), + unnestNode->markerName(), std::make_shared( unnestNode->sources().front()->outputType())); } diff --git a/velox/exec/TraceUtil.h b/velox/exec/TraceUtil.h index b9814e0e90d9..96837ff8b171 100644 --- a/velox/exec/TraceUtil.h +++ 
b/velox/exec/TraceUtil.h @@ -27,6 +27,45 @@ namespace facebook::velox::exec::trace { +static const std::vector kEmptySources; + +class DummySourceNode final : public core::PlanNode { + public: + explicit DummySourceNode(RowTypePtr outputType) + : PlanNode(""), outputType_(std::move(outputType)) {} + + const RowTypePtr& outputType() const override { + return outputType_; + } + + const std::vector& sources() const override { + return kEmptySources; + } + + std::string_view name() const override { + return "DummySource"; + } + + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = "DummySource"; + obj["outputType"] = outputType_->serialize(); + return obj; + } + + static core::PlanNodePtr create(const folly::dynamic& obj, void* context) { + return std::make_shared( + ISerializable::deserialize(obj["outputType"])); + } + + private: + void addDetails(std::stringstream& stream) const override { + // Nothing to add. + } + + const RowTypePtr outputType_; +}; + /// Creates a directory to store the query trace metdata and data. void createTraceDirectory( const std::string& traceDir, diff --git a/velox/exec/Unnest.cpp b/velox/exec/Unnest.cpp index 99aa13fdd3b4..2155ef616ac0 100644 --- a/velox/exec/Unnest.cpp +++ b/velox/exec/Unnest.cpp @@ -42,7 +42,7 @@ Unnest::Unnest( unnestNode->id(), "Unnest"), withOrdinality_(unnestNode->hasOrdinality()), - withEmptyUnnestValue_(unnestNode->hasEmptyUnnestValue()), + withMarker_(unnestNode->hasMarker()), maxOutputSize_( driverCtx->queryConfig().unnestSplitOutput() ? 
outputBatchRows() @@ -60,11 +60,11 @@ Unnest::Unnest( unnestDecoded_.resize(unnestVariables.size()); column_index_t checkOutputChannel = outputType_->size() - 1; - if (withEmptyUnnestValue_) { + if (withMarker_) { VELOX_CHECK_EQ( outputType_->childAt(checkOutputChannel), BOOLEAN(), - "Empty unnest value column should be BOOLEAN type."); + "Marker column should be BOOLEAN type."); --checkOutputChannel; } if (withOrdinality_) { @@ -194,7 +194,7 @@ Unnest::RowRange Unnest::extractRowRange(vector_size_t inputSize) const { if (rawMaxSizes_[inputRow] == 0) { VELOX_CHECK_EQ(remainingInnerRows, 0); hasEmptyUnnestValue = true; - if (withEmptyUnnestValue_) { + if (withMarker_) { remainingInnerRows = 1; } } @@ -237,12 +237,11 @@ void Unnest::generateRepeatedColumns( vector_size_t* rawRepeatedIndices = repeatedIndices->asMutable(); - const bool generateEmptyUnnestValue = - withEmptyUnnestValue_ && range.hasEmptyUnnestValue; + const bool generateMarker = withMarker_ && range.hasEmptyUnnestValue; vector_size_t index{0}; VELOX_CHECK_GT(range.numInputRows, 0); // Record the row number to process. - if (generateEmptyUnnestValue) { + if (generateMarker) { range.forEachRow( [&](vector_size_t row, vector_size_t /*start*/, vector_size_t size) { if (FOLLY_UNLIKELY(size == 0)) { @@ -302,7 +301,7 @@ const Unnest::UnnestChannelEncoding Unnest::generateEncodingForChannel( range.forEachRow( [&](vector_size_t row, vector_size_t start, vector_size_t size) { const auto end = start + size; - if (size == 0 && withEmptyUnnestValue_) { + if (size == 0 && withMarker_) { identityMapping = false; bits::setNull(rawNulls, index++, true); } else if (!currentDecoded.isNullAt(row)) { @@ -343,9 +342,8 @@ VectorPtr Unnest::generateOrdinalityVector(const RowRange& range) { // Set the ordinality at each result row to be the index of the element in // the original array (or map) plus one. 
auto* rawOrdinality = ordinalityVector->mutableRawValues(); - const bool hasEmptyUnnestValue = - withEmptyUnnestValue_ && range.hasEmptyUnnestValue; - if (!hasEmptyUnnestValue) { + const bool hasMarker = withMarker_ && range.hasEmptyUnnestValue; + if (!hasMarker) { range.forEachRow( [&](vector_size_t /*row*/, vector_size_t start, vector_size_t size) { std::iota(rawOrdinality, rawOrdinality + size, start + 1); @@ -374,28 +372,28 @@ VectorPtr Unnest::generateOrdinalityVector(const RowRange& range) { return ordinalityVector; } -VectorPtr Unnest::generateEmptyUnnestValueVector(const RowRange& range) { - VELOX_CHECK(withEmptyUnnestValue_); +VectorPtr Unnest::generateMarkerVector(const RowRange& range) { + VELOX_CHECK(withMarker_); VELOX_DCHECK_GT(range.numInputRows, 0); if (!range.hasEmptyUnnestValue) { return BaseVector::createConstant( - BOOLEAN(), false, range.numInnerRows, pool()); + BOOLEAN(), true, range.numInnerRows, pool()); } - // Create a vector with all elements set to false initially assuming most + // Create a vector with all elements set to true initially assuming most // output rows have non-empty unnest values. - auto emptyBuffer = - velox::AlignedBuffer::allocate(range.numInnerRows, pool(), false); - auto emptyVector = std::make_shared>( + auto markerBuffer = + velox::AlignedBuffer::allocate(range.numInnerRows, pool(), true); + auto markerVector = std::make_shared>( pool(), /*type=*/BOOLEAN(), /*nulls=*/nullptr, range.numInnerRows, - /*values=*/std::move(emptyBuffer), + /*values=*/std::move(markerBuffer), /*stringBuffers=*/std::vector{}); - // Set each output row has empty unnest values. - auto* const rawEmpty = emptyVector->mutableRawValues(); + // Set each output row with empty unnest values to false. 
+ auto* const rawMarker = markerVector->mutableRawValues(); size_t index{0}; range.forEachRow( [&](vector_size_t /*row*/, vector_size_t start, vector_size_t size) { @@ -403,12 +401,12 @@ VectorPtr Unnest::generateEmptyUnnestValueVector(const RowRange& range) { index += size; } else { VELOX_DCHECK_EQ(size, 0); - bits::setBit(rawEmpty, index++, true); + bits::setBit(rawMarker, index++, false); } }, rawMaxSizes_, firstInnerRowStart_); - return emptyVector; + return markerVector; } RowVectorPtr Unnest::generateOutput(const RowRange& range) { @@ -443,8 +441,8 @@ RowVectorPtr Unnest::generateOutput(const RowRange& range) { if (withOrdinality_) { outputs[outputColumnIndex++] = generateOrdinalityVector(range); } - if (withEmptyUnnestValue_) { - outputs[outputColumnIndex++] = generateEmptyUnnestValueVector(range); + if (withMarker_) { + outputs[outputColumnIndex++] = generateMarkerVector(range); } return std::make_shared( diff --git a/velox/exec/Unnest.h b/velox/exec/Unnest.h index 58812e86ddd3..e1fb967525c3 100644 --- a/velox/exec/Unnest.h +++ b/velox/exec/Unnest.h @@ -134,15 +134,15 @@ class Unnest : public Operator { // Invoked by generateOutput for the ordinality column. VectorPtr generateOrdinalityVector(const RowRange& rowRange); - // Invoked by generateOutput for the empty unnest value column. - VectorPtr generateEmptyUnnestValueVector(const RowRange& rowRange); + // Invoked by generateOutput for the marker column. + VectorPtr generateMarkerVector(const RowRange& rowRange); // Invoked when finish one input batch processing to reset the internal // execution state for the next batch. void finishInput(); const bool withOrdinality_; - const bool withEmptyUnnestValue_; + const bool withMarker_; // The maximum number of output batch rows. 
const vector_size_t maxOutputSize_; diff --git a/velox/exec/VectorHasher.cpp b/velox/exec/VectorHasher.cpp index e502575aae7c..4dfb561111b9 100644 --- a/velox/exec/VectorHasher.cpp +++ b/velox/exec/VectorHasher.cpp @@ -41,6 +41,9 @@ namespace facebook::velox::exec { case TypeKind::BIGINT: { \ return TEMPLATE_FUNC(__VA_ARGS__); \ } \ + case TypeKind::HUGEINT: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ case TypeKind::VARCHAR: \ case TypeKind::VARBINARY: { \ return TEMPLATE_FUNC(__VA_ARGS__); \ @@ -736,6 +739,7 @@ void extendRange( extendRange(reserve, min, max); break; case TypeKind::BIGINT: + case TypeKind::HUGEINT: case TypeKind::VARCHAR: case TypeKind::VARBINARY: case TypeKind::TIMESTAMP: diff --git a/velox/exec/VectorHasher.h b/velox/exec/VectorHasher.h index ebd3534c1169..5449c6ca762c 100644 --- a/velox/exec/VectorHasher.h +++ b/velox/exec/VectorHasher.h @@ -291,6 +291,7 @@ class VectorHasher { case TypeKind::SMALLINT: case TypeKind::INTEGER: case TypeKind::BIGINT: + case TypeKind::HUGEINT: case TypeKind::VARCHAR: case TypeKind::VARBINARY: case TypeKind::TIMESTAMP: diff --git a/velox/exec/Window.cpp b/velox/exec/Window.cpp index 321a469ddae8..39ca0d644caf 100644 --- a/velox/exec/Window.cpp +++ b/velox/exec/Window.cpp @@ -64,6 +64,7 @@ Window::Window( driverCtx->queryConfig().prefixSortMaxStringPrefixLength()}, spillConfig, &nonReclaimableSection_, + &stats_, spillStats_.get()); } } diff --git a/velox/exec/Window.h b/velox/exec/Window.h index 8cbd3b1ee85c..563169ff4226 100644 --- a/velox/exec/Window.h +++ b/velox/exec/Window.h @@ -67,6 +67,11 @@ class Window : public Operator { void reclaim(uint64_t targetBytes, memory::MemoryReclaimer::Stats& stats) override; + /// Runtime statistics holding total number of batches read from spilled data. + /// 0 if no spilling occurred. + static inline const std::string kWindowSpillReadNumBatches{ + "windowSpillReadNumBatches"}; + private: // Used for k preceding/following frames. 
Index is the column index if k is a // column. value is used to read column values from the column index when k diff --git a/velox/exec/benchmarks/AtomicsBench.cpp b/velox/exec/benchmarks/AtomicsBench.cpp index 74343e9303f2..c86991ef6fcb 100644 --- a/velox/exec/benchmarks/AtomicsBench.cpp +++ b/velox/exec/benchmarks/AtomicsBench.cpp @@ -16,16 +16,39 @@ #include #include -#include #include #include #include +#include "velox/common/base/Portability.h" #include "velox/exec/OneWayStatusFlag.h" -using namespace ::testing; -using namespace facebook::velox; -static const size_t kNumThreads = 88; -static const size_t kNumIterations = 10000; +namespace { + +using facebook::velox::exec::OneWayStatusFlag; +constexpr size_t kNumThreads = 88; +constexpr size_t kNumIterations = 10000; + +#if defined(__x86_64__) && !defined(TSAN_BUILD) + +class OneWayStatusFlagUnsafe { + public: + bool check() const { + return fastStatus_ || atomicStatus_.load(); + } + + void set() { + if (!fastStatus_) { + atomicStatus_.store(true); + fastStatus_ = true; + } + } + + private: + bool fastStatus_{false}; + std::atomic_bool atomicStatus_{false}; +}; + +#endif void runParallelUpdates( std::function callback, @@ -46,12 +69,28 @@ void runParallelUpdates( } } -BENCHMARK(std_atomic_bool_write) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_write_seq_cst) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { flag.store(true); + bool dummy{}; + folly::doNotOptimizeAway(dummy); + } + }, + kNumThreads, // Threads + kNumIterations); // Iterations per thread +} + +BENCHMARK(std_atomic_bool_write_release) { + std::atomic_bool flag{false}; + runParallelUpdates( + [&](size_t iters) { + for (size_t i = 0; i < iters; ++i) { + flag.store(true, std::memory_order_release); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads @@ -59,76 +98,82 @@ BENCHMARK(std_atomic_bool_write) { } BENCHMARK(std_atomic_bool_write_relaxed) { - 
std::atomic flag{false}; + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { flag.store(true, std::memory_order_relaxed); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(std_atomic_bool_read_write_relaxed) { - std::atomic flag{false}; +BENCHMARK(one_way_flag_write) { + OneWayStatusFlag flag; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - if (!flag.load(std::memory_order_relaxed)) { - flag.store(true, std::memory_order_acq_rel); - } + flag.set(); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(one_way_flag_write) { - exec::OneWayStatusFlag flag; +#if defined(__x86_64__) && !defined(TSAN_BUILD) + +BENCHMARK(one_way_flag_unsafe_write) { + OneWayStatusFlagUnsafe flag; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { flag.set(); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } +#endif + // Read Benchmarks -BENCHMARK(std_atomic_bool_read) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_read_seq_cst) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - folly::doNotOptimizeAway(flag.load()); + folly::doNotOptimizeAway(flag.load(std::memory_order_seq_cst)); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(std_atomic_bool_relaxed_read) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_read_acquire) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - folly::doNotOptimizeAway(flag.load(std::memory_order_relaxed)); + folly::doNotOptimizeAway(flag.load(std::memory_order_acquire)); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } 
-BENCHMARK(std_atomic_bool_read_relaxed_acquire) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_read_relaxed) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - folly::doNotOptimizeAway( - flag.load(std::memory_order_relaxed) || - flag.load(std::memory_order_acquire)); + folly::doNotOptimizeAway(flag.load(std::memory_order_relaxed)); } }, kNumThreads, // Threads @@ -136,7 +181,21 @@ BENCHMARK(std_atomic_bool_read_relaxed_acquire) { } BENCHMARK(one_way_flag_read) { - exec::OneWayStatusFlag flag; + OneWayStatusFlag flag; + runParallelUpdates( + [&](size_t iters) { + for (size_t i = 0; i < iters; ++i) { + folly::doNotOptimizeAway(flag.check()); + } + }, + kNumThreads, // Threads + kNumIterations); // Iterations per thread +} + +#if defined(__x86_64__) && !defined(TSAN_BUILD) + +BENCHMARK(one_way_flag_unsafe_read) { + OneWayStatusFlagUnsafe flag; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { @@ -147,6 +206,10 @@ BENCHMARK(one_way_flag_read) { kNumIterations); // Iterations per thread } +#endif + +} // namespace + int main(int argc, char** argv) { folly::Init init(&argc, &argv); folly::runBenchmarks(); diff --git a/velox/exec/benchmarks/CMakeLists.txt b/velox/exec/benchmarks/CMakeLists.txt index 5dfe0ff853e0..83830a0bee5a 100644 --- a/velox/exec/benchmarks/CMakeLists.txt +++ b/velox/exec/benchmarks/CMakeLists.txt @@ -158,3 +158,7 @@ target_link_libraries( velox_vector_fuzzer Folly::follybenchmark ) + +add_executable(velox_atomics_benchmark AtomicsBench.cpp) + +target_link_libraries(velox_atomics_benchmark Folly::follybenchmark) diff --git a/velox/exec/benchmarks/MergeBenchmark.cpp b/velox/exec/benchmarks/MergeBenchmark.cpp index f5fbaee0ff6f..118007abe0d4 100644 --- a/velox/exec/benchmarks/MergeBenchmark.cpp +++ b/velox/exec/benchmarks/MergeBenchmark.cpp @@ -19,7 +19,7 @@ #include -#include "velox/exec/TreeOfLosers.h" +#include "velox/common/base/TreeOfLosers.h" 
#include "velox/exec/tests/utils/MergeTestBase.h" using namespace facebook::velox; diff --git a/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp b/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp index 27025f723515..1a4ebdd383ab 100644 --- a/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp +++ b/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp @@ -192,7 +192,8 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { std::move(plan), 0, core::QueryCtx::create(executor_.get()), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); } else { const std::unordered_map queryConfigMap( {{core::QueryConfig::kPrefixSortNormalizedKeyMaxBytes, "0"}}); @@ -202,7 +203,8 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { 0, core::QueryCtx::create( executor_.get(), core::QueryConfig(queryConfigMap)), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); } } diff --git a/velox/exec/fuzzer/CMakeLists.txt b/velox/exec/fuzzer/CMakeLists.txt index a826e6785adf..d030d0f8fc12 100644 --- a/velox/exec/fuzzer/CMakeLists.txt +++ b/velox/exec/fuzzer/CMakeLists.txt @@ -198,6 +198,53 @@ target_link_libraries( velox_vector_fuzzer ) +# LocalRunnerService (requires FBThrift support) +if(VELOX_ENABLE_REMOTE_FUNCTIONS) + # Generate Thrift library for LocalRunnerService + include(FBThriftCppLibrary) + add_fbthrift_cpp_library( + local_runner_service_thrift + if/LocalRunnerService.thrift + SERVICES + LocalRunnerService + ) + + target_compile_options(local_runner_service_thrift PRIVATE -Wno-error=deprecated-declarations) + + # LocalRunnerService Library + add_library(velox_local_runner_service_lib LocalRunnerService.cpp) + + target_link_libraries( + velox_local_runner_service_lib + local_runner_service_thrift + velox_core + velox_exec + velox_exec_test_lib + velox_expression + velox_functions_prestosql + velox_common_base + velox_memory + Folly::folly + FBThrift::thriftcpp2 + gflags + glog::glog + ) 
+ + target_include_directories(velox_local_runner_service_lib PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) + + # LocalRunnerService Executable + add_executable(velox_local_runner_service_runner LocalRunnerServiceRunner.cpp) + + target_link_libraries( + velox_local_runner_service_runner + velox_local_runner_service_lib + velox_functions_prestosql + gtest + gflags + Folly::folly + ) +endif() + if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() diff --git a/velox/exec/fuzzer/LocalRunnerService.cpp b/velox/exec/fuzzer/LocalRunnerService.cpp new file mode 100644 index 000000000000..66f82e732b24 --- /dev/null +++ b/velox/exec/fuzzer/LocalRunnerService.cpp @@ -0,0 +1,363 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include "velox/core/QueryCtx.h" +#include "velox/exec/fuzzer/LocalRunnerService.h" +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/expression/EvalCtx.h" + +using namespace facebook::velox; +using namespace facebook::velox::runner; + +namespace facebook::velox::runner { +namespace { + +class StdoutCapture { + public: + StdoutCapture() { + oldCoutBuf_ = std::cout.rdbuf(); + std::cout.rdbuf(buffer_.rdbuf()); + } + ~StdoutCapture() { + std::cout.rdbuf(oldCoutBuf_); + } + std::string str() { + return buffer_.str(); + } + + private: + std::stringstream buffer_; + std::streambuf* oldCoutBuf_; +}; + +std::pair execute( + const std::string& serializedPlan, + const std::string& queryId, + std::shared_ptr pool) { + StdoutCapture stdoutCapture; + + core::PlanNodePtr plan; + try { + folly::dynamic planJson = folly::parseJson(serializedPlan); + VLOG(1) << "Deserializing plan:\n" << serializedPlan; + plan = core::PlanNode::deserialize(planJson, pool.get()); + } catch (const std::exception& e) { + throw std::runtime_error( + fmt::format("Failed to deserialize plan: {}", e.what())); + } + VLOG(1) << "Deserialized plan:\n" << plan->toString(true, true); + + try { + exec::test::AssertQueryBuilder queryBuilder(plan); + + std::shared_ptr task; + auto results = queryBuilder.copyResults(pool.get(), task); + + return {results, stdoutCapture.str()}; + } catch (const std::exception& e) { + throw std::runtime_error( + fmt::format("Error executing query: {}", e.what())); + } +} + +} // namespace + +ScalarValue getScalarValue(const VectorPtr& vector, vector_size_t rowIdx) { + ScalarValue scalar; + + switch (vector->typeKind()) { + case TypeKind::BOOLEAN: + scalar.boolValue_ref() = + vector->as>()->valueAt(rowIdx); + break; + case TypeKind::TINYINT: + scalar.tinyintValue_ref() = + vector->as>()->valueAt(rowIdx); + break; + case 
TypeKind::SMALLINT: + scalar.smallintValue_ref() = + vector->as>()->valueAt(rowIdx); + break; + case TypeKind::INTEGER: + scalar.integerValue_ref() = + vector->as>()->valueAt(rowIdx); + break; + case TypeKind::BIGINT: + scalar.bigintValue_ref() = + vector->as>()->valueAt(rowIdx); + break; + case TypeKind::REAL: + scalar.realValue_ref() = + vector->as>()->valueAt(rowIdx); + break; + case TypeKind::DOUBLE: + scalar.doubleValue_ref() = + vector->as>()->valueAt(rowIdx); + break; + case TypeKind::VARCHAR: + scalar.varcharValue_ref() = + vector->as>()->valueAt(rowIdx).str(); + break; + case TypeKind::VARBINARY: { + const auto& binValue = + vector->as>()->valueAt(rowIdx); + scalar.varbinaryValue_ref() = + std::string(binValue.data(), binValue.size()); + break; + } + case TypeKind::TIMESTAMP: { + const auto& ts = + vector->as>()->valueAt( + rowIdx); + facebook::velox::runner::Timestamp timestampValue; + timestampValue.seconds_ref() = ts.getSeconds(); + timestampValue.nanos_ref() = ts.getNanos(); + scalar.timestampValue_ref() = std::move(timestampValue); + break; + } + case TypeKind::HUGEINT: { + const auto& hugeint = + vector->as>()->valueAt(rowIdx); + facebook::velox::runner::i128 hugeintValue; + hugeintValue.msb_ref() = static_cast(hugeint >> 64); + hugeintValue.lsb_ref() = + static_cast(hugeint & 0xFFFFFFFFFFFFFFFFULL); + scalar.hugeintValue_ref() = std::move(hugeintValue); + break; + } + default: + VELOX_FAIL(fmt::format("Unsupported scalar type: {}", vector->type())); + } + + return scalar; +} + +ComplexValue getComplexValue( + const VectorPtr& vector, + vector_size_t rowIdx, + const exec::EvalCtx& evalCtx) { + ComplexValue complex; + + exec::LocalDecodedVector decoder( + evalCtx, *vector, SelectivityVector(vector->size())); + auto& decoded = *decoder.get(); + rowIdx = decoded.index(rowIdx); + + switch (vector->typeKind()) { + case TypeKind::ARRAY: { + auto arrayVector = decoded.base()->as(); + auto elements = arrayVector->elements(); + auto offset = 
arrayVector->offsetAt(rowIdx); + auto size = arrayVector->sizeAt(rowIdx); + + facebook::velox::runner::Array arrayValue; + + for (auto i = 0; i < size; ++i) { + auto elementIdx = offset + i; + + Value elementValue; + if (elements->isNullAt(elementIdx)) { + elementValue.isNull() = true; + } else { + elementValue = convertValue(elements, elementIdx, evalCtx); + } + arrayValue.values()->push_back(std::move(elementValue)); + } + + complex.arrayValue_ref() = std::move(arrayValue); + break; + } + case TypeKind::MAP: { + auto mapVector = decoded.base()->as(); + auto keys = mapVector->mapKeys(); + auto values = mapVector->mapValues(); + auto offset = mapVector->offsetAt(rowIdx); + auto size = mapVector->sizeAt(rowIdx); + + facebook::velox::runner::Map mapValue; + + for (auto i = 0; i < size; ++i) { + Value keyValue, valueValue; + + VELOX_CHECK(!(keys->isNullAt(offset + i)), "Map key cannot be null"); + keyValue = convertValue(keys, offset + i, evalCtx); + if (values->isNullAt(offset + i)) { + valueValue.isNull() = true; + } else { + valueValue = convertValue(values, offset + i, evalCtx); + } + (*mapValue.values())[std::move(keyValue)] = std::move(valueValue); + } + + complex.mapValue_ref() = std::move(mapValue); + break; + } + case TypeKind::ROW: { + auto rowVector = decoded.base()->as(); + facebook::velox::runner::Row rowValue; + + for (auto i = 0; i < rowVector->childrenSize(); ++i) { + auto childVector = rowVector->childAt(i); + + Value fieldValue; + if (childVector->isNullAt(rowIdx)) { + fieldValue.isNull() = true; + } else { + fieldValue = convertValue(childVector, rowIdx, evalCtx); + } + rowValue.fieldValues()->push_back(std::move(fieldValue)); + } + + complex.rowValue_ref() = std::move(rowValue); + break; + } + default: + VELOX_FAIL(fmt::format("Unsupported complex type: {}", vector->type())); + } + + return complex; +} + +Value convertValue( + const VectorPtr& vector, + vector_size_t rowIdx, + const exec::EvalCtx& evalCtx) { + Value value; + if 
(vector->isNullAt(rowIdx)) { + value.isNull() = true; + } else { + value.isNull() = false; + switch (vector->typeKind()) { + case TypeKind::BOOLEAN: + case TypeKind::TINYINT: + case TypeKind::SMALLINT: + case TypeKind::INTEGER: + case TypeKind::BIGINT: + case TypeKind::REAL: + case TypeKind::DOUBLE: + case TypeKind::VARCHAR: + case TypeKind::VARBINARY: + case TypeKind::TIMESTAMP: + case TypeKind::HUGEINT: + value.scalarValue_ref() = getScalarValue(vector, rowIdx); + break; + case TypeKind::ARRAY: + case TypeKind::MAP: + case TypeKind::ROW: + value.complexValue_ref() = getComplexValue(vector, rowIdx, evalCtx); + break; + default: + VELOX_FAIL(fmt::format("Unsupported type: {}", vector->type())); + } + } + return value; +} + +std::vector convertVector( + const VectorPtr& vector, + vector_size_t size, + const exec::EvalCtx& evalCtx) { + std::vector rows; + for (vector_size_t rowIdx = 0; rowIdx < size; ++rowIdx) { + Value value = convertValue(vector, rowIdx, evalCtx); + rows.push_back(value); + } + return rows; +} + +std::vector convertToBatches( + const std::vector& rowVectors, + const exec::EvalCtx& evalCtx) { + std::vector results; + + if (rowVectors.empty()) { + return results; + } + + for (const auto& rowVector : rowVectors) { + Batch result; + const auto& rowType = rowVector->type()->asRow(); + + for (auto i = 0; i < rowType.size(); ++i) { + result.columnNames()->push_back(rowType.nameOf(i)); + result.columnTypes()->push_back(rowType.childAt(i)->toString()); + } + + result.numRows() = rowVector->size(); + + const auto numColumns = rowVector->childrenSize(); + result.columns()->resize(numColumns); + + for (auto colIdx = 0; colIdx < numColumns; ++colIdx) { + (*result.columns())[colIdx].rows() = + convertVector(rowVector->childAt(colIdx), rowVector->size(), evalCtx); + } + + results.push_back(std::move(result)); + } + + return results; +} + +void LocalRunnerServiceHandler::execute( + ExecutePlanResponse& response, + std::unique_ptr request) { + VLOG(1) << "Received 
executePlan request"; + + std::shared_ptr pool = + memory::memoryManager()->addLeafPool(); + + RowVectorPtr results; + std::string output; + + try { + VLOG(1) << "Executing plan in service handler"; + std::tie(results, output) = + ::execute(*request->serializedPlan(), *request->queryId(), pool); + + VLOG(1) << fmt::format( + "Result:\nresult rowVector: {}\nstdout: {}", + results->toString(true), + output); + } catch (const std::exception& e) { + VLOG(1) << "Exception executing plan: " << e.what(); + response.success() = false; + response.errorMessage() = e.what(); + return; + } + + auto queryCtx = core::QueryCtx::create(); + core::ExecCtx execCtx(pool.get(), queryCtx.get()); + exec::EvalCtx evalCtx(&execCtx); + + VLOG(1) << "Converting results to Thrift response"; + auto resultBatches = convertToBatches({results}, evalCtx); + response.results() = std::move(resultBatches); + response.output() = output; + response.success() = true; + VLOG(1) << "Response sent"; +} + +} // namespace facebook::velox::runner diff --git a/velox/exec/fuzzer/LocalRunnerService.h b/velox/exec/fuzzer/LocalRunnerService.h new file mode 100644 index 000000000000..8d9baf2ca6a2 --- /dev/null +++ b/velox/exec/fuzzer/LocalRunnerService.h @@ -0,0 +1,85 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// Thrift service implementation and library for executing Velox query plans +/// remotely. 
+/// +/// This file provides conversion utilities and a service handler for the +/// LocalRunnerService. It enables remote execution of serialized Velox +/// expression evaluation primarily used for fuzzing where query plans need to +/// be executed on remote workers. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include "velox/expression/EvalCtx.h" + +namespace facebook::velox::runner { + +/// Extracts a scalar (primitive) value from a Velox vector at the specified +/// row. This function handles all primitive types supported by Velox, such as: +/// TINYINT, INTEGER, BIGINT, etc. +ScalarValue getScalarValue(const VectorPtr& vector, vector_size_t rowIdx); + +/// Extracts a complex (nested) value from a Velox vector at the specified row. +/// This function handles all complex types supported by Velox: ARRAY, MAP and +/// ROW. The function recursively converts nested structures. +ComplexValue getComplexValue( + const VectorPtr& vector, + vector_size_t rowIdx, + const exec::EvalCtx& evalCtx); + +/// Converts a Velox vector value at a specific row to a Thrift Value and serves +/// as the entry point for value conversion that can be either primitive or +/// complex. Output value can either be a scalar or complex value (as mentioned, +/// using the above). This is where NULL is also defined. +Value convertValue( + const VectorPtr& vector, + vector_size_t rowIdx, + const exec::EvalCtx& evalCtx); + +/// Converts a Velox vector into a corresponding Thrift struct vector. +std::vector convertVector( + const VectorPtr& vector, + vector_size_t size, + const exec::EvalCtx& evalCtx); + +/// Converts a collection of Velox RowVectors into Thrift Batches. +std::vector convertToBatches( + const std::vector& rowVectors, + const exec::EvalCtx& evalCtx); + +/// Thrift service handler for executing Velox query plans. +/// Executes a serialized Velox query plan. 
This method deserializes the plan +/// from JSON, configures execution, runs the query plan to completion, +/// converts results to Thrift Batches and captures any subsequent errors or +/// output. The method returns a Thrift response containing the results. +class LocalRunnerServiceHandler + : public apache::thrift::ServiceHandler { + public: + void execute( + ExecutePlanResponse& response, + std::unique_ptr request) override; +}; + +} // namespace facebook::velox::runner diff --git a/velox/exec/fuzzer/LocalRunnerServiceRunner.cpp b/velox/exec/fuzzer/LocalRunnerServiceRunner.cpp new file mode 100644 index 000000000000..7bb15f9d355a --- /dev/null +++ b/velox/exec/fuzzer/LocalRunnerServiceRunner.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include + +#include "velox/core/ITypedExpr.h" +#include "velox/core/PlanNode.h" +#include "velox/exec/fuzzer/LocalRunnerService.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/type/Type.h" + +using namespace facebook::velox; +using namespace facebook::velox::runner; + +DEFINE_int32( + port, + 9091, + "LocalRunnerService port number to be used in conjunction with ExpressionFuzzerTest flag local_runner_port."); + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + folly::Init init(&argc, &argv); + + memory::initializeMemoryManager(memory::MemoryManager::Options{}); + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + functions::prestosql::registerAllScalarFunctions(); + functions::prestosql::registerInternalFunctions(); + + std::shared_ptr thriftServer = + std::make_shared(); + thriftServer->setPort(FLAGS_port); + thriftServer->setInterface(std::make_shared()); + thriftServer->setNumIOWorkerThreads(1); + thriftServer->setNumCPUWorkerThreads(1); + + VLOG(1) << "Starting LocalRunnerService"; + thriftServer->serve(); + + return 0; +} diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index 27b524a3af57..b85a2885eb96 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -318,17 +318,35 @@ bool PrestoQueryRunner::isConstantExprSupported( } bool PrestoQueryRunner::isSupported(const exec::FunctionSignature& signature) { - // TODO: support queries with these types. Among the types below, hugeint is - // not a native type in Presto, so fuzzer should not use it as the type of - // cast-to or constant literals. Hyperloglog and TDigest can only be casted - // from varbinary and cannot be used as the type of constant literals. 
- // Interval year to month can only be casted from NULL and cannot be used as - // the type of constant literals. Json, Ipaddress, Ipprefix, and UUID require - // special handling, because Presto requires literals of these types to be - // valid, and doesn't allow creating HIVE columns of these types. + // TODO: support queries with these types. + // Types not supported by PrestoQueryRunner and their reasons: + // + // hugeint: + // - Not a native type in Presto + // - Fuzzer should not use it for cast-to or constant literals + // + // interval year to month: + // - Can only be casted from NULL + // - Cannot be used as constant literal types + // + // ipaddress, ipprefix, uuid: + // - Require special handling in Presto + // - Presto requires literals of these types to be valid + // - Cannot create HIVE columns of these types + // + // geometry: + // - Under development in Presto + // - Cannot be used as constant literals + // - Expected differences between Presto Java and Velox C++ implementations + // + // p4hyperloglog: + // - Not a native type in Presto + // - Cannot create HIVE columns of these types return !( usesTypeName(signature, "interval year to month") || usesTypeName(signature, "hugeint") || + usesTypeName(signature, "geometry") || usesTypeName(signature, "time") || + usesTypeName(signature, "p4hyperloglog") || usesInputTypeName(signature, "ipaddress") || usesInputTypeName(signature, "ipprefix") || usesInputTypeName(signature, "uuid")); diff --git a/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp b/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp index 059263c99412..1508d1e2a28c 100644 --- a/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp @@ -66,6 +66,9 @@ intermediateTypeTransforms() { std::make_shared( SFMSKETCH(), VARBINARY())}, {JSON(), std::make_shared()}, + {TIME(), + std::make_shared( + TIME(), VARCHAR())}, {BINGTILE(), 
std::make_shared( BINGTILE(), BIGINT())}, diff --git a/velox/exec/fuzzer/WriterFuzzer.cpp b/velox/exec/fuzzer/WriterFuzzer.cpp index 9c2132c7403e..61e65a1021ae 100644 --- a/velox/exec/fuzzer/WriterFuzzer.cpp +++ b/velox/exec/fuzzer/WriterFuzzer.cpp @@ -16,6 +16,7 @@ #include "velox/exec/fuzzer/WriterFuzzer.h" #include +#include #include #include diff --git a/velox/exec/fuzzer/if/LocalRunnerService.thrift b/velox/exec/fuzzer/if/LocalRunnerService.thrift new file mode 100644 index 000000000000..a0301c1a3423 --- /dev/null +++ b/velox/exec/fuzzer/if/LocalRunnerService.thrift @@ -0,0 +1,128 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file defines a Thrift service for executing Velox query plans remotely. +// It provides a type system for representing query results and a service interface +// for executing serialized query plans with configurable parallelism. + +namespace cpp2 facebook.velox.runner + +// Represents a HUGEINT value by splitting it into most significant and least +// significant components. +struct i128 { + 1: i64 msb; + 2: i64 lsb; +} + +// Represents a timestamp value with seconds and nanoseconds components. +struct Timestamp { + 1: i64 seconds; + 2: i64 nanos; +} + +// A tagged union representing all supported scalar (primitive) data types. +// Only one field will be set at a time, corresponding to the actual type of the value. 
+union ScalarValue { + 1: bool boolValue; + 2: byte tinyintValue; + 3: i16 smallintValue; + 4: i32 integerValue; + 5: i64 bigintValue; + 6: float realValue; + 7: double doubleValue; + 8: string varcharValue; + 9: binary varbinaryValue; + 10: Timestamp timestampValue; + 11: i128 hugeintValue; +} + +// Represents an ARRAY type, containing an ordered list of values. +// All values in the array are of the same type. +struct Array { + 1: list values; +} + +// Represents a MAP type, containing key-value pairs. +// Keys and values can be of any supported type. +struct Map { + 1: map values; +} + +// Represents a ROW (struct) type, containing an ordered list of field values. +// Each field can have a different type. +struct Row { + 1: list fieldValues; +} + +// A tagged union representing complex (nested) data types. +// Only one field will be set at a time, corresponding to the actual complex type. +union ComplexValue { + 1: Array arrayValue; + 2: Map mapValue; + 3: Row rowValue; +} + +// Represents a single value of any supported data type. +// A value can be: +// - A scalar (primitive) value +// - A complex (nested) value +// - NULL (indicated by isNull = true) +struct Value { + 1: optional ScalarValue scalarValue; + 2: optional ComplexValue complexValue; + 3: bool isNull; +} + +// Represents a single column of data in columnar format. +// Contains all values for one column across multiple rows. +struct Column { + 1: list rows; +} + +// Represents a batch of rows in columnar format. +// This is the fundamental unit of data transfer, containing multiple columns +// and metadata about the schema. +struct Batch { + 1: list columns; + 2: list columnNames; + 3: list columnTypes; + 4: i32 numRows; +} + +// Request to execute a serialized Velox query plan. +struct ExecutePlanRequest { + 1: string serializedPlan; + 2: string queryId; + 3: i32 numWorkers = 4; + 4: i32 numDrivers = 2; +} + +// Response from executing a query plan. 
+struct ExecutePlanResponse { + 1: list results; + 2: string output; + 3: bool success; + 4: optional string errorMessage; +} + +// Service for executing Velox query plans locally. +// This service enables remote execution of serialized query plans with +// configurable parallelism, returning results in a structured format. +service LocalRunnerService { + // Inputs a Thrift request and executes a serialized Velox query plan and + // returns the results as a Thrift response. + ExecutePlanResponse execute(1: ExecutePlanRequest request); +} diff --git a/velox/exec/fuzzer/tests/CMakeLists.txt b/velox/exec/fuzzer/tests/CMakeLists.txt index b7478d931a2e..a77e6691a5cc 100644 --- a/velox/exec/fuzzer/tests/CMakeLists.txt +++ b/velox/exec/fuzzer/tests/CMakeLists.txt @@ -16,3 +16,24 @@ add_executable(presto_sql_test PrestoSqlTest.cpp) add_test(presto_sql_test presto_sql_test) target_link_libraries(presto_sql_test velox_fuzzer_util velox_presto_types) + +# LocalRunnerService Test (requires FBThrift support) +if(VELOX_ENABLE_REMOTE_FUNCTIONS) + add_executable(local_runner_service_test LocalRunnerServiceTest.cpp) + add_test(local_runner_service_test local_runner_service_test) + + target_link_libraries( + local_runner_service_test + velox_local_runner_service_lib + local_runner_service_thrift + velox_core + velox_type + velox_functions_prestosql + velox_functions_test_lib + velox_vector_test_lib + velox_common_base + Folly::folly + gtest + gtest_main + ) +endif() diff --git a/velox/exec/fuzzer/tests/LocalRunnerServiceTest.cpp b/velox/exec/fuzzer/tests/LocalRunnerServiceTest.cpp new file mode 100644 index 000000000000..c5e8cf7f30ec --- /dev/null +++ b/velox/exec/fuzzer/tests/LocalRunnerServiceTest.cpp @@ -0,0 +1,299 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "velox/common/memory/Memory.h" +#include "velox/core/PlanNode.h" +#include "velox/exec/fuzzer/LocalRunnerService.h" +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/type/Type.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::velox; +using namespace facebook::velox::runner; +using namespace facebook::velox::test; + +namespace facebook::velox::fuzzer::test { +class LocalRunnerServiceTest : public functions::test::FunctionBaseTest { + protected: + void SetUp() override { + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + functions::prestosql::registerAllScalarFunctions(); + functions::prestosql::registerInternalFunctions(); + + createTestData(); + } + + void createTestData() { + // Create test vectors for different data types + auto rowType = ROW({ + {"bool_col", BOOLEAN()}, + {"int_col", INTEGER()}, + {"bigint_col", BIGINT()}, + {"double_col", DOUBLE()}, + {"varchar_col", VARCHAR()}, + {"timestamp_col", TIMESTAMP()}, + {"array_col", ARRAY(ARRAY(INTEGER()))}, + }); + + testRowVector_ = makeRowVector( + {"bool_col", + "int_col", + "bigint_col", + "double_col", + "varchar_col", + "timestamp_col", + "array_col"}, + { + makeFlatVector( + 10, + [](auto row) { return row % 2 == 0; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + 
[](auto row) { return row; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return row * 1.1; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return fmt::format("str_{}", row); }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return facebook::velox::Timestamp(row, 0); }, + [](auto row) { return row % 3 == 0; }), + makeNestedArrayVectorFromJson( + {"[[1, 2]]", + "[[3]]", + "[[4]]", + "[[5, 6]]", + "[[7]]", + "[[8]]", + "[[9]]", + "[[10]]", + "[[11]]", + "[[12]]"}), + }); + + testRowVectorWrapped_ = makeRowVector( + {"bool_col", + "int_col", + "bigint_col", + "double_col", + "varchar_col", + "timestamp_col", + "array_col"}, + { + makeFlatVector( + 5, + [](auto row) { return row % 2 == 0; }, + [](auto row) { return row % 3 == 0; }), + wrapInDictionary( + makeIndices(5, [](auto row) { return (row * 17 + 3) % 10; }), + 5, + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; })), + BaseVector::wrapInConstant( + 5, + 0, + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; })), + makeFlatVector( + 5, + [](auto row) { return row * 1.1; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 5, + [](auto row) { return fmt::format("str_{}", row); }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 5, + [](auto row) { return facebook::velox::Timestamp(row, 0); }, + [](auto row) { return row % 3 == 0; }), + makeNestedArrayVectorFromJson( + {"[[1, 2]]", + "[[3]]", + "[[4]]", + "[[5, 6]]", + "[[7]]", + "[[8]]", + "[[9]]", + "[[10]]", + "[[11]]", + "[[12]]"}), + }); + } + + RowVectorPtr testRowVector_; + RowVectorPtr testRowVectorWrapped_; +}; + +TEST_F(LocalRunnerServiceTest, ConvertToBatches) { + auto queryCtx = core::QueryCtx::create(); + core::ExecCtx 
execCtx(rootPool_.get(), queryCtx.get()); + exec::EvalCtx evalCtx(&execCtx); + auto result = + facebook::velox::runner::convertToBatches({testRowVector_}, evalCtx); + + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0].columnNames()->size(), 7); + ASSERT_EQ(result[0].columnTypes()->size(), 7); + ASSERT_EQ(result[0].numRows(), 10); + ASSERT_EQ(result[0].columns()->size(), 7); + ASSERT_EQ(result[0].columns()[0].rows()->size(), 10); + ASSERT_EQ(result[0].columns()[1].rows()->size(), 10); + ASSERT_EQ(result[0].columns()[2].rows()->size(), 10); + ASSERT_EQ(result[0].columns()[3].rows()->size(), 10); + ASSERT_EQ(result[0].columns()[4].rows()->size(), 10); + ASSERT_EQ(result[0].columns()[5].rows()->size(), 10); +} + +TEST_F(LocalRunnerServiceTest, ConvertToBatchesWrapped) { + auto queryCtx = core::QueryCtx::create(); + core::ExecCtx execCtx(rootPool_.get(), queryCtx.get()); + exec::EvalCtx evalCtx(&execCtx); + auto result = facebook::velox::runner::convertToBatches( + {testRowVectorWrapped_}, evalCtx); + + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0].columnNames()->size(), 7); + ASSERT_EQ(result[0].columnTypes()->size(), 7); + ASSERT_EQ(result[0].numRows(), 5); + ASSERT_EQ(result[0].columns()->size(), 7); + ASSERT_EQ(result[0].columns()[0].rows()->size(), 5); + ASSERT_EQ(result[0].columns()[1].rows()->size(), 5); + ASSERT_EQ(result[0].columns()[2].rows()->size(), 5); + ASSERT_EQ(result[0].columns()[3].rows()->size(), 5); + ASSERT_EQ(result[0].columns()[4].rows()->size(), 5); + ASSERT_EQ(result[0].columns()[5].rows()->size(), 5); + + auto kWrappedConstantIndex = 2; + ASSERT_EQ( + result[0].columns()[kWrappedConstantIndex].rows()[0], + result[0].columns()[kWrappedConstantIndex].rows()[1]); + ASSERT_EQ( + result[0].columns()[kWrappedConstantIndex].rows()[0], + result[0].columns()[kWrappedConstantIndex].rows()[2]); + ASSERT_EQ( + result[0].columns()[kWrappedConstantIndex].rows()[0], + result[0].columns()[kWrappedConstantIndex].rows()[3]); + ASSERT_EQ( + 
result[0].columns()[kWrappedConstantIndex].rows()[0], + result[0].columns()[kWrappedConstantIndex].rows()[4]); + + auto reference = + facebook::velox::runner::convertToBatches({testRowVector_}, evalCtx); + auto kWrappedDictionaryIndex = 1; + ASSERT_EQ( + result[0].columns()[kWrappedDictionaryIndex].rows()[0], + reference[0].columns()[kWrappedDictionaryIndex].rows()[3]); + ASSERT_EQ( + result[0].columns()[kWrappedDictionaryIndex].rows()[1], + reference[0].columns()[kWrappedDictionaryIndex].rows()[0]); + ASSERT_EQ( + result[0].columns()[kWrappedDictionaryIndex].rows()[2], + reference[0].columns()[kWrappedDictionaryIndex].rows()[7]); + ASSERT_EQ( + result[0].columns()[kWrappedDictionaryIndex].rows()[3], + reference[0].columns()[kWrappedDictionaryIndex].rows()[4]); + ASSERT_EQ( + result[0].columns()[kWrappedDictionaryIndex].rows()[4], + reference[0].columns()[kWrappedDictionaryIndex].rows()[1]); + ASSERT_NE( + result[0].columns()[kWrappedDictionaryIndex].rows()[0], + reference[0].columns()[kWrappedDictionaryIndex].rows()[1]); +} + +TEST_F(LocalRunnerServiceTest, ServiceHandlerMockRequestIntegration) { + LocalRunnerServiceHandler handler; + + auto request = std::make_unique(); + // Serialized plan for the following: + // expressions: (p0:DOUBLE, plus(null,0.1646418017335236)) + request->serializedPlan() = + 
R"({"names":["p0","p1"],"id":"project","name":"ProjectNode","sources":[{"name":"ProjectNode","id":"transform","projections":[{"name":"FieldAccessTypedExpr","type":{"name":"Type","type":"BIGINT"},"inputs":[{"name":"InputTypedExpr","type":{"type":"ROW","name":"Type","names":["row_number"],"cTypes":[{"name":"Type","type":"BIGINT"}]}}],"fieldName":"row_number"}],"names":["row_number"],"sources":[{"name":"ValuesNode","id":"efb6650a_8541_4214_82dd_9792a4965380","data":"AAAAAF4AAAB7ImNUeXBlcyI6W3sidHlwZSI6IkJJR0lOVCIsIm5hbWUiOiJUeXBlIn1dLCJuYW1lcyI6WyJyb3dfbnVtYmVyIl0sInR5cGUiOiJST1ciLCJuYW1lIjoiVHlwZSJ9AQAAAAABAAAAAQAAAAAfAAAAeyJ0eXBlIjoiQklHSU5UIiwibmFtZSI6IlR5cGUifQEAAAAAAQgAAAAAAAAAAAAAAA==","parallelizable":false,"repeatTimes":1}]}],"projections":[{"name":"CallTypedExpr","type":{"name":"Type","type":"DOUBLE"},"functionName":"plus","inputs":[{"name":"ConstantTypedExpr","type":{"name":"Type","type":"DOUBLE"},"valueVector":"AQAAAB8AAAB7InR5cGUiOiJET1VCTEUiLCJuYW1lIjoiVHlwZSJ9AQAAAAE="},{"name":"ConstantTypedExpr","type":{"name":"Type","type":"DOUBLE"},"valueVector":"AQAAAB8AAAB7InR5cGUiOiJET1VCTEUiLCJuYW1lIjoiVHlwZSJ9AQAAAAABAAAAifsSxT8="}]},{"name":"FieldAccessTypedExpr","type":{"name":"Type","type":"BIGINT"},"fieldName":"row_number"}]})"; + request->queryId() = "query1"; + + ExecutePlanResponse response; + handler.execute(response, std::move(request)); + + EXPECT_TRUE(*response.success()); + EXPECT_EQ(response.results()->size(), 1); + + const auto& batch = (*response.results()).front(); + EXPECT_EQ(batch.columnNames()->size(), 2); + EXPECT_EQ((*batch.columnNames())[0], "p0"); + EXPECT_EQ(batch.columnTypes()->size(), 2); + EXPECT_EQ((*batch.columnTypes())[0], "DOUBLE"); + EXPECT_EQ(batch.numRows(), 1); + EXPECT_EQ(batch.columns()->size(), 2); + + const auto& column = (*batch.columns())[0]; + EXPECT_EQ(column.rows()->size(), 1); + const auto& row = (*column.rows())[0]; + EXPECT_TRUE(*row.isNull()); + EXPECT_FALSE(row.scalarValue_ref().has_value()); +} + 
+TEST_F(LocalRunnerServiceTest, ServiceHandlerMockRequestIntegrationFailure) { + LocalRunnerServiceHandler handler; + + auto request = std::make_unique(); + // Serialized plan for the following: + // expressions: (p0:TINYINT, divide(89,"c0") + // Will encounter divide by zero error. + request->serializedPlan() = + R"({"projections":[{"inputs":[{"valueVector":"AQAAACAAAAB7InR5cGUiOiJUSU5ZSU5UIiwibmFtZSI6IlR5cGUifQEAAAAAAVk=","type":{"type":"TINYINT","name":"Type"},"name":"ConstantTypedExpr"},{"fieldName":"c0","type":{"type":"TINYINT","name":"Type"},"name":"FieldAccessTypedExpr"}],"functionName":"divide","type":{"type":"TINYINT","name":"Type"},"name":"CallTypedExpr"},{"fieldName":"row_number","type":{"type":"BIGINT","name":"Type"},"name":"FieldAccessTypedExpr"}],"sources":[{"projections":[{"inputs":[{"type":{"cTypes":[{"type":"TINYINT","name":"Type"},{"type":"BIGINT","name":"Type"}],"names":["c0","row_number"],"type":"ROW","name":"Type"},"name":"InputTypedExpr"}],"fieldName":"c0","type":{"type":"TINYINT","name":"Type"},"name":"FieldAccessTypedExpr"},{"inputs":[{"type":{"cTypes":[{"type":"TINYINT","name":"Type"},{"type":"BIGINT","name":"Type"}],"names":["c0","row_number"],"type":"ROW","name":"Type"},"name":"InputTypedExpr"}],"fieldName":"row_number","type":{"type":"BIGINT","name":"Type"},"name":"FieldAccessTypedExpr"}],"sources":[{"parallelizable":false,"repeatTimes":1,"data":"AAAAAIQAAAB7ImNUeXBlcyI6W3sidHlwZSI6IlRJTllJTlQiLCJuYW1lIjoiVHlwZSJ9LHsidHlwZSI6IkJJR0lOVCIsIm5hbWUiOiJUeXBlIn1dLCJuYW1lcyI6WyJjMCIsInJvd19udW1iZXIiXSwidHlwZSI6IlJPVyIsIm5hbWUiOiJUeXBlIn0KAAAAAAIAAAABAgAAACAAAAB7InR5cGUiOiJUSU5ZSU5UIiwibmFtZSI6IlR5cGUifQoAAAAAKAAAAAMAAAACAAAABgAAAAAAAAABAAAACAAAAAUAAAAAAAAACAAAAAUAAAACAAAAIAAAAHsidHlwZSI6IlRJTllJTlQiLCJuYW1lIjoiVHlwZSJ9CgAAAAECAAAA9/8oAAAACQAAAAQAAAAJAAAAAAAAAAYAAAAHAAAABAAAAAYAAAAAAAAAAAAAAAIAAAAgAAAAeyJ0eXBlIjoiVElOWUlOVCIsIm5hbWUiOiJUeXBlIn0KAAAAAQIAAAD7oigAAAAJAAAAAQAAAAkAAAAHAAAAAAAAAAUAAAAEAAAAAwAAAAEAAAAAAAAAAAAAACAAAAB7InR5cGUiOiJUSU5ZSU5
UIiwibmFtZSI6IlR5cGUifQoAAAAAAQoAAABTOkYvJBw5ZUAAAQAAAAAfAAAAeyJ0eXBlIjoiQklHSU5UIiwibmFtZSI6IlR5cGUifQoAAAAAAVAAAAAAAAAAAAAAAAEAAAAAAAAAAgAAAAAAAAADAAAAAAAAAAQAAAAAAAAABQAAAAAAAAAGAAAAAAAAAAcAAAAAAAAACAAAAAAAAAAJAAAAAAAAAA==","id":"d69f11dc_1f0e_40ae_8c5d_2cde4b784a12","name":"ValuesNode"}],"names":["c0","row_number"],"id":"transform","name":"ProjectNode"}],"names":["p0","p1"],"id":"project","name":"ProjectNode"})"; + request->queryId() = "query1"; + + ExecutePlanResponse response; + handler.execute(response, std::move(request)); + + ASSERT_TRUE(response.errorMessage().has_value()); + auto errorMsg = response.errorMessage().value(); + EXPECT_NE(errorMsg.find("Error Source: USER"), std::string::npos); + EXPECT_NE(errorMsg.find("Error Code: ARITHMETIC_ERROR"), std::string::npos); + EXPECT_NE(errorMsg.find("Reason: division by zero"), std::string::npos); + + EXPECT_FALSE(*response.success()); + EXPECT_EQ(response.results()->size(), 0); +} + +} // namespace facebook::velox::fuzzer::test + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + folly::Init init(&argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/velox/exec/prefixsort/CMakeLists.txt b/velox/exec/prefixsort/CMakeLists.txt index aac7db501ecd..ccaca96a1c46 100644 --- a/velox/exec/prefixsort/CMakeLists.txt +++ b/velox/exec/prefixsort/CMakeLists.txt @@ -21,3 +21,5 @@ endif() if(${VELOX_ENABLE_BENCHMARKS}) add_subdirectory(benchmarks) endif() + +velox_install_library_headers() diff --git a/velox/exec/prefixsort/PrefixSortEncoder.h b/velox/exec/prefixsort/PrefixSortEncoder.h index 945b1de5e160..87756792a726 100644 --- a/velox/exec/prefixsort/PrefixSortEncoder.h +++ b/velox/exec/prefixsort/PrefixSortEncoder.h @@ -28,7 +28,7 @@ namespace facebook::velox::exec::prefixsort { class PrefixSortEncoder { public: PrefixSortEncoder(bool ascending, bool nullsFirst) - : ascending_(ascending), nullsFirst_(nullsFirst){}; + : ascending_(ascending), nullsFirst_(nullsFirst) {} /// Encode native 
primitive types(such as uint64_t, int64_t, uint32_t, /// int32_t, uint16_t, int16_t, float, double, Timestamp). diff --git a/velox/exec/tests/AggregationTest.cpp b/velox/exec/tests/AggregationTest.cpp index 58b6d58b443a..6ab71cad302d 100644 --- a/velox/exec/tests/AggregationTest.cpp +++ b/velox/exec/tests/AggregationTest.cpp @@ -550,47 +550,6 @@ TEST_F(AggregationTest, missingLambdaFunction) { readCursor(params), "Aggregate function not registered: missing-lambda"); } -TEST_F(AggregationTest, DISABLED_resultTypeMismatch) { - using Step = core::AggregationNode::Step; - - registerAggregateFunction( - "test_aggregate", - {AggregateFunctionSignatureBuilder() - .returnType("bigint") - .intermediateType("bigint") - .argumentType("bigint") - .build()}, - [&](Step /*step*/, - const std::vector& /*argTypes*/, - const TypePtr& /*resultType*/, - const core::QueryConfig& /*config*/) - -> std::unique_ptr { VELOX_UNREACHABLE(); }, - false /*registerCompanionFunctions*/, - true /*overwrite*/); - - for (auto step : {Step::kIntermediate, Step::kPartial}) { - VELOX_ASSERT_THROW( - Aggregate::create( - "test_aggregate", - step, - std::vector{BIGINT()}, - INTEGER(), - core::QueryConfig{{}}), - "Intermediate type mismatch"); - } - - for (auto step : {Step::kFinal, Step::kSingle}) { - VELOX_ASSERT_THROW( - Aggregate::create( - "test_aggregate", - step, - std::vector{BIGINT()}, - INTEGER(), - core::QueryConfig{{}}), - "Final type mismatch"); - } -} - TEST_F(AggregationTest, global) { auto vectors = makeVectors(rowType_, 10, 100); createDuckDbTable(vectors); diff --git a/velox/exec/tests/AsyncConnectorTest.cpp b/velox/exec/tests/AsyncConnectorTest.cpp index 3904a610fd69..260ed4c49288 100644 --- a/velox/exec/tests/AsyncConnectorTest.cpp +++ b/velox/exec/tests/AsyncConnectorTest.cpp @@ -126,7 +126,7 @@ class TestDataSource : public connector::DataSource { return 0; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { return {}; } diff --git 
a/velox/exec/tests/CMakeLists.txt b/velox/exec/tests/CMakeLists.txt index 010f343dea05..25f6e9f87f6f 100644 --- a/velox/exec/tests/CMakeLists.txt +++ b/velox/exec/tests/CMakeLists.txt @@ -76,6 +76,7 @@ add_executable( ParallelProjectTest.cpp PartitionedOutputTest.cpp PlanNodeSerdeTest.cpp + PlanNodeStatsTest.cpp PlanNodeToStringTest.cpp PlanNodeToSummaryStringTest.cpp PrefixSortTest.cpp @@ -88,6 +89,7 @@ add_executable( ScaledScanControllerTest.cpp ScaleWriterLocalPartitionTest.cpp SortBufferTest.cpp + SpatialIndexTest.cpp SpillerTest.cpp SpillTest.cpp SplitListenerTest.cpp @@ -229,6 +231,7 @@ add_executable( PrestoQueryRunnerQDigestTransformTest.cpp PrestoQueryRunnerJsonTransformTest.cpp PrestoQueryRunnerIntervalTransformTest.cpp + PrestoQueryRunnerTimeTransformTest.cpp PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp ) diff --git a/velox/exec/tests/DriverTest.cpp b/velox/exec/tests/DriverTest.cpp index 275ed3bd3455..f232b73f513a 100644 --- a/velox/exec/tests/DriverTest.cpp +++ b/velox/exec/tests/DriverTest.cpp @@ -1614,7 +1614,8 @@ DEBUG_ONLY_TEST_F(DriverTest, driverCpuTimeSlicingCheck) { 0, core::QueryCtx::create( driverExecutor_.get(), core::QueryConfig{std::move(queryConfig)}), - testParam.executionMode); + testParam.executionMode, + exec::Consumer{}); while (task->next() != nullptr) { } } diff --git a/velox/exec/tests/ExchangeClientTest.cpp b/velox/exec/tests/ExchangeClientTest.cpp index f8a6a374eba6..b5ae077fa837 100644 --- a/velox/exec/tests/ExchangeClientTest.cpp +++ b/velox/exec/tests/ExchangeClientTest.cpp @@ -91,7 +91,8 @@ class ExchangeClientTest core::PlanFragment{plan}, 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); } int32_t enqueue( diff --git a/velox/exec/tests/FunctionSignatureBuilderTest.cpp b/velox/exec/tests/FunctionSignatureBuilderTest.cpp index 6ac14b0592ad..c792fef36a89 100644 --- a/velox/exec/tests/FunctionSignatureBuilderTest.cpp +++ 
b/velox/exec/tests/FunctionSignatureBuilderTest.cpp @@ -124,7 +124,7 @@ TEST_F(FunctionSignatureBuilderTest, typeParamTests) { .returnType("integer") .argumentType("row(..., varchar)") .build(), - "Failed to parse type signature [row(..., varchar)]: syntax error, unexpected COMMA"); + "Failed to parse type signature [row(..., varchar)]: syntax error, unexpected ELLIPSIS"); // Type params cant have type params. VELOX_ASSERT_THROW( @@ -155,6 +155,23 @@ TEST_F(FunctionSignatureBuilderTest, anyInReturn) { "Type 'Any' cannot appear in return type"); } +TEST_F(FunctionSignatureBuilderTest, homogeneousRowInReturn) { + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType("row(T, ...)") + .argumentType("T") + .build(), + "Homogeneous row cannot appear in return type"); + + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("array(row(bigint, ...))") + .argumentType("bigint") + .build(), + "Homogeneous row cannot appear in return type"); +} + TEST_F(FunctionSignatureBuilderTest, scalarConstantFlags) { { auto signature = FunctionSignatureBuilder() diff --git a/velox/exec/tests/GroupedExecutionTest.cpp b/velox/exec/tests/GroupedExecutionTest.cpp index 07fc21ba3f7d..6a13989ba3ca 100644 --- a/velox/exec/tests/GroupedExecutionTest.cpp +++ b/velox/exec/tests/GroupedExecutionTest.cpp @@ -675,16 +675,25 @@ DEBUG_ONLY_TEST_F( } })); + const auto spillDirectory = exec::test::TempDirectoryPath::create(); + std::optional spillOpts; + if (testData.enableSpill) { + spillOpts = common::SpillDiskOptions{ + .spillDirPath = spillDirectory->getPath(), + .spillDirCreated = true, + .spillDirCreateCb = nullptr}; + } + auto task = exec::Task::create( "0", std::move(planFragment), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - if (testData.enableSpill) { - task->setSpillDirectory(spillDirectory->getPath()); - } + Task::ExecutionMode::kParallel, + 
/*consumer=*/Consumer{}, + /*memoryArbitrationPriority=*/0, + spillOpts, + /*onError=*/nullptr); // 'numDriversPerGroup' drivers max to execute one group at a time. task->start(numDriversPerGroup, testData.groupConcurrency); @@ -817,15 +826,21 @@ DEBUG_ONLY_TEST_F( memory::testingRunArbitration(op->pool()); })); + const auto spillDirectory = exec::test::TempDirectoryPath::create(); + common::SpillDiskOptions spillOpts{ + .spillDirPath = spillDirectory->getPath(), + .spillDirCreated = true, + .spillDirCreateCb = nullptr}; + auto task = exec::Task::create( "0", std::move(planFragment), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - - task->setSpillDirectory(spillDirectory->getPath()); + Task::ExecutionMode::kParallel, + Consumer{}, + /*memoryArbitrationPriority=*/0, + spillOpts); // 'numDriversPerGroup' drivers max to execute one group at a time. task->start(numDriversPerGroup, 1); @@ -848,8 +863,8 @@ DEBUG_ONLY_TEST_F( } // Total drivers should be numDriversPerGroup * (numGroups + 1), but since - // probe does not receive termination signal, it cannot signal the build side - // to finish. we expect only build's numDriversPerGroup finished. + // probe does not receive termination signal, it cannot signal the build + // side to finish. we expect only build's numDriversPerGroup finished. 
waitForFinishedDrivers(task, numDriversPerGroup); // 'Delete results' from output buffer triggers 'set all output consumed', diff --git a/velox/exec/tests/IndexLookupJoinTest.cpp b/velox/exec/tests/IndexLookupJoinTest.cpp index f1e056f117e6..9305bf75c602 100644 --- a/velox/exec/tests/IndexLookupJoinTest.cpp +++ b/velox/exec/tests/IndexLookupJoinTest.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/IndexLookupJoin.h" +#include "fmt/format.h" #include "folly/experimental/EventCount.h" #include "gmock/gmock.h" #include "gtest/gtest-matchers.h" @@ -224,14 +225,13 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0"}, - {"u0"}, - indexTableScan, - {}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}, - joinType) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); @@ -246,14 +246,14 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) - .indexLookupJoin( - {"t0"}, - {"u0"}, - indexTableScan, - {"contains(t3, u0)", "contains(t4, u1)"}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}, - joinType) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t3, u0)", "contains(t4, u1)"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); @@ -265,6 +265,79 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { } // with between join conditions. 
+ for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { + auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions( + {"u0 between t0 AND t1", + "u1 between t1 AND 10", + "u1 between 10 AND t1"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 3); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with mix join conditions. + for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { + auto plan = + PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t3, u0)", "u1 between 10 AND t1"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with has match column. 
+ { + auto plan = + PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t3, u0)", "u1 between 10 AND t1"}) + .hasMarker(true) + .outputLayout({"t0", "u1", "t2", "t1", "match"}) + .joinType(core::JoinType::kLeft) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with filter. for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) @@ -272,23 +345,25 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { {"t0"}, {"u0"}, indexTableScan, - {"u0 between t0 AND t1", - "u1 between t1 AND 10", - "u1 between 10 AND t1"}, - /*includeMatchColumn=*/false, + {}, + /*filter=*/"t1 % 2 = 0", + /*hasMarker=*/false, {"t0", "u1", "t2", "t1"}, joinType) .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); - ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 3); + ASSERT_TRUE(indexLookupJoinNode->joinConditions().empty()); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), "eq(mod(ROW[\"t1\"],2),0)"); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), kTestIndexConnectorName); testSerde(plan); } - // with mix join conditions. + // with join conditions and filter. 
for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) @@ -296,21 +371,26 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { {"t0"}, {"u0"}, indexTableScan, - {"contains(t3, u0)", "u1 between 10 AND t1"}, - /*includeMatchColumn=*/false, + {"contains(t3, u0)"}, + /*filter=*/"u1 % 2 = 0 AND t2 > 5", + /*hasMarker=*/false, {"t0", "u1", "t2", "t1"}, joinType) .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); - ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 1); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), + "and(eq(mod(ROW[\"u1\"],2),0),gt(ROW[\"t2\"],5))"); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), kTestIndexConnectorName); testSerde(plan); } - // with has match column. + // with filter and marker for left join. { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) @@ -318,14 +398,46 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { {"t0"}, {"u0"}, indexTableScan, - {"contains(t3, u0)", "u1 between 10 AND t1"}, - /*includeMatchColumn=*/true, + {"u1 between 10 AND t1"}, + /*filter=*/"t2 < u2", + /*hasMarker=*/true, {"t0", "u1", "t2", "t1", "match"}, core::JoinType::kLeft) .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); - ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 1); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), + "lt(ROW[\"t2\"],ROW[\"u2\"])"); + ASSERT_TRUE(indexLookupJoinNode->hasMarker()); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with complex filter expression. 
+ { + auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .filter("(t1 + u1) * 2 > 100 OR t2 = u2") + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(core::JoinType::kInner) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_TRUE(indexLookupJoinNode->joinConditions().empty()); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), + "or(gt(multiply(plus(ROW[\"t1\"],ROW[\"u1\"]),2),100),eq(ROW[\"t2\"],ROW[\"u2\"]))"); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), kTestIndexConnectorName); @@ -337,14 +449,13 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_USER_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0"}, - {"u0"}, - indexTableScan, - {}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}, - core::JoinType::kFull) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(core::JoinType::kFull) + .endIndexLookupJoin() .planNode(), "Unsupported index lookup join type FULL"); } @@ -354,13 +465,12 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_USER_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0"}, - {"u0"}, - nonIndexTableScan, - {}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(nonIndexTableScan) + .outputLayout({"t0", "u1", "t2", "t1"}) + .endIndexLookupJoin() .planNode(), "The lookup table handle hive_table from connector test-hive doesn't support index lookup"); } @@ -370,13 +480,13 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_THROW( PlanBuilder(planNodeIdGenerator) 
.values({left}) - .indexLookupJoin( - {"t0", "t1"}, - {"u0"}, - indexTableScan, - {"contains(t4, u0)"}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}) + .startIndexLookupJoin() + .leftKeys({"t0", "t1"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t4, u0)"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .endIndexLookupJoin() .planNode(), "The index lookup join node requires same number of join keys on left and right sides"); } @@ -386,13 +496,13 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {}, - {}, - indexTableScan, - {"contains(t4, u0)"}, - /*includeMatchColumn=*/false, - {"t0", "u1", "t2", "t1"}) + .startIndexLookupJoin() + .leftKeys({}) + .rightKeys({}) + .indexSource(indexTableScan) + .joinConditions({"contains(t4, u0)"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .endIndexLookupJoin() .planNode(), "The index lookup join node requires at least one join key"); } @@ -789,7 +899,8 @@ TEST_P(IndexLookupJoinTest, equalJoin) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -810,7 +921,8 @@ TEST_P(IndexLookupJoinTest, equalJoin) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( @@ -1262,7 +1374,8 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { {"t0", "t1"}, {"u0", "u1"}, {testData.betweenCondition}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -1283,7 +1396,8 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { {"t0", "t1"}, {"u0", "u1"}, {testData.betweenCondition}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, 
testData.outputColumns); verifyResultWithMatchColumn( @@ -1601,7 +1715,8 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { {"t0", "t1"}, {"u0", "u1"}, {testData.inCondition}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -1622,7 +1737,8 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { {"t0", "t1"}, {"u0", "u1"}, {testData.inCondition}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( @@ -1756,7 +1872,8 @@ TEST_P(IndexLookupJoinTest, prefixKeysEqualJoin) { leftKeys, rightKeys, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -1777,7 +1894,8 @@ TEST_P(IndexLookupJoinTest, prefixKeysEqualJoin) { leftKeys, rightKeys, {}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( @@ -1894,7 +2012,8 @@ TEST_P(IndexLookupJoinTest, prefixKeysbetweenJoinCondition) { {"t0"}, {"u0"}, {testData.betweenCondition}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -1915,7 +2034,8 @@ TEST_P(IndexLookupJoinTest, prefixKeysbetweenJoinCondition) { {"t0"}, {"u0"}, {testData.betweenCondition}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( @@ -2033,7 +2153,8 @@ TEST_P(IndexLookupJoinTest, prefixInJoinCondition) { {"t0"}, {"u0"}, {testData.inCondition}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -2054,7 +2175,8 @@ TEST_P(IndexLookupJoinTest, prefixInJoinCondition) { {"t0"}, {"u0"}, {testData.inCondition}, - /*includeMatchColumn=*/true, + /*filter=*/"", + 
/*hasMarker=*/true, testData.joinType, testData.outputColumns); verifyResultWithMatchColumn( @@ -2107,7 +2229,8 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, connectorError) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u0", "u1", "u2", "t5"}); VELOX_ASSERT_THROW( @@ -2176,7 +2299,8 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u3", "t5"}); std::thread queryThread([&] { @@ -2279,7 +2403,8 @@ TEST_P(IndexLookupJoinTest, outputBatchSizeWithInnerJoin) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"t4", "u5"}); const auto task = @@ -2387,7 +2512,8 @@ TEST_P(IndexLookupJoinTest, outputBatchSizeWithLeftJoin) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kLeft, {"t4", "u5"}); const auto task = @@ -2420,7 +2546,8 @@ TEST_P(IndexLookupJoinTest, outputBatchSizeWithLeftJoin) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, core::JoinType::kLeft, {"t4", "u5"}); verifyResultWithMatchColumn( @@ -2471,7 +2598,8 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u3", "t5"}); auto task = runLookupQuery( @@ -2555,7 +2683,8 @@ TEST_P(IndexLookupJoinTest, barrier) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u3", "t5"}); @@ -2644,7 +2773,8 @@ TEST_P(IndexLookupJoinTest, nullKeys) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - 
/*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u3", "t5"}); @@ -2663,7 +2793,8 @@ TEST_P(IndexLookupJoinTest, nullKeys) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kLeft, {"u3", "t5"}); @@ -2683,7 +2814,8 @@ TEST_P(IndexLookupJoinTest, nullKeys) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/true, + /*filter=*/"", + /*hasMarker=*/true, core::JoinType::kLeft, {"u3", "t5"}); verifyResultWithMatchColumn( @@ -2722,7 +2854,8 @@ TEST_P(IndexLookupJoinTest, joinFuzzer) { {"t0"}, {"u0"}, {"contains(t4, u1)", "u2 between t1 and t2"}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u0", "u4", "t0", "t1", "t4"}); runLookupQuery( @@ -2773,7 +2906,8 @@ TEST_P(IndexLookupJoinTest, tableRowsWithDuplicateKeys) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, - /*includeMatchColumn=*/false, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, scanOutput); runLookupQuery( @@ -2785,6 +2919,379 @@ TEST_P(IndexLookupJoinTest, tableRowsWithDuplicateKeys) { GetParam().numPrefetches, "SELECT u.c0, u.c1, u.c2, u.c3, u.c4, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND u.c2 = t.c2"); } + +TEST_P(IndexLookupJoinTest, withFilter) { + struct { + std::vector keyCardinalities; + int numProbeBatches; + int numRowsPerProbeBatch; + int matchPct; + std::vector scanOutputColumns; + std::vector outputColumns; + core::JoinType joinType; + std::string filter; + std::string duckDbVerifySql; + + std::string debugString() const { + return fmt::format( + "keyCardinalities: {}, numProbeBatches: {}, numRowsPerProbeBatch: {}, matchPct: {}, " + "scanOutputColumns: {}, outputColumns: {}, joinType: {}, filter: {}, " + "duckDbVerifySql: {}", + folly::join(",", keyCardinalities), + numProbeBatches, + numRowsPerProbeBatch, + matchPct, + folly::join(",", scanOutputColumns), + folly::join(",", 
outputColumns), + core::JoinTypeName::toName(joinType), + filter, + duckDbVerifySql); + } + } testSettings[] = { + // Inner join with filter on probe side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c3 = t.c3"}, + // Inner join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c3 != t.c3"}, + + // Inner join with filter on lookup side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 = u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 = u.c3"}, + // Inner join with filter always be false. 
+ {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 != u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 != u.c3"}, + + // Inner join with filter on both side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 % 2 = 0 AND t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 % 2 = 0 AND t.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 = u3 AND t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 = u.c3 AND t.c3 = t.c3"}, + // Inner join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 != u3 AND t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 != u.c3 AND t.c3 != t.c3"}, + + // Left join with filter on probe side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c3 % 2 = 0"}, + // Left join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c3 = t.c3"}, + // Inner join with filter always be false. 
+ {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c3 != t.c3"}, + + // Left join with filter on lookup side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 = u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 = u.c3"}, + // Left join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 != u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 != u.c3"}, + + // Left join with filter on both side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 % 2 = 0 AND t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 % 2 = 0 AND t.c3 % 2 = 0"}, + // Left join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 = u3 AND t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 = u.c3 AND t.c3 = t.c3"}, + // Left join with filter always be false. 
+ {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 != u3 AND t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 != u.c3 AND t.c3 != t.c3"}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + SequenceTableData tableData; + generateIndexTableData(testData.keyCardinalities, tableData, pool_); + auto probeVectors = generateProbeInput( + testData.numProbeBatches, + testData.numRowsPerProbeBatch, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + testData.matchPct); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyData, + tableData.valueData, + *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType(testData.scanOutputColumns), + makeIndexColumnHandles(testData.scanOutputColumns)); + + // Create a plan with filter + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + testData.filter, + /*hasMarker=*/false, + testData.joinType, + testData.outputColumns); + + runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 32, + GetParam().numPrefetches, + testData.duckDbVerifySql); + + if (testData.joinType != core::JoinType::kLeft) { + continue; + } + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + testData.filter, + /*hasMarker=*/true, + 
testData.joinType, + testData.outputColumns); + verifyResultWithMatchColumn( + plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + } +} + +TEST_P(IndexLookupJoinTest, mixedFilterBatches) { + // Create SequenceTableData using VectorTestBase utilities + SequenceTableData tableData; + + const std::string dummyString("test"); + StringView dummyStringView(dummyString); + // Create table key data (u0, u1, u2) using makeFlatVector + auto u0 = makeFlatVector(64, [&](auto row) { return row % 8; }); + auto u1 = makeFlatVector(64, [&](auto row) { return row % 8; }); + auto u2 = makeFlatVector(64, [&](auto row) { return row % 8; }); + tableData.keyData = makeRowVector({"u0", "u1", "u2"}, {u0, u1, u2}); + + // Create table value data (u3, u4, u5) using makeFlatVector + auto u3 = makeFlatVector(64, [&](auto row) { return row; }); + auto u4 = makeFlatVector(64, [&](auto row) { return row; }); + auto u5 = makeFlatVector( + 64, [&](auto /*unused*/) { return dummyStringView; }); + tableData.valueData = makeRowVector({"u3", "u4", "u5"}, {u3, u4, u5}); + + // Create complete table data by combining key and value data + tableData.tableData = makeRowVector( + {"u0", "u1", "u2", "u3", "u4", "u5"}, {u0, u1, u2, u3, u4, u5}); + + // Create probe vectors using makeArrayVectorFromJson in a loop + std::vector probeVectors; + probeVectors.reserve(5); + for (int i = 0; i < 5; ++i) { + probeVectors.push_back(makeRowVector( + {"t0", "t1", "t2", "t3", "t4", "t5"}, + {makeFlatVector(128, [&](auto row) { return row; }), + makeFlatVector(128, [&](auto row) { return row; }), + makeFlatVector(128, [&](auto row) { return row; }), + makeFlatVector(128, [&](auto row) { return row; }), + makeArrayVector( + 128, + [](vector_size_t /*unused*/) { return 1; }, + [](vector_size_t, vector_size_t) { return 1; }), + makeFlatVector( + 128, [&](auto /*unused*/) { return dummyStringView; })})); + } + + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", 
probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + "t3 > 4", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"t1", "u1", "u2", "u3", "u5"}); + + AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .config( + core::QueryConfig::kIndexLookupJoinMaxPrefetchBatches, + std::to_string(GetParam().numPrefetches)) + .config(core::QueryConfig::kPreferredOutputBatchRows, "4") + .config(core::QueryConfig::kIndexLookupJoinSplitOutput, "true") + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .serialExecution(GetParam().serialExecution) + .barrierExecution(GetParam().serialExecution) + .assertResults( + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2 AND t.c3 > 4"); +} } // namespace VELOX_INSTANTIATE_TEST_SUITE_P( diff --git a/velox/exec/tests/MergeTest.cpp b/velox/exec/tests/MergeTest.cpp index 02f80fcd140b..843d524fc5f3 100644 --- a/velox/exec/tests/MergeTest.cpp +++ b/velox/exec/tests/MergeTest.cpp @@ -601,6 +601,38 @@ TEST_F(MergeTest, localMergeSpillPartialEmpty) { ASSERT_EQ(planStats.spilledRows, 120); } +DEBUG_ONLY_TEST_F(MergeTest, localMergeSpillWithException) { + std::vector vectors; + for (int32_t i = 0; i < 9; ++i) { + constexpr vector_size_t batchSize = 137; + auto c0 = makeFlatVector( + batchSize, [&](auto row) { return batchSize * i + row; }, nullEvery(5)); + auto c1 = makeFlatVector( + 
batchSize, [&](auto row) { return row; }, nullEvery(5)); + auto c2 = makeFlatVector( + batchSize, [](auto row) { return row * 0.1; }, nullEvery(11)); + auto c3 = makeFlatVector(batchSize, [](auto row) { + return StringView::makeInline(std::to_string(row)); + }); + vectors.push_back(makeRowVector({c0, c1, c2, c3})); + } + createDuckDbTable(vectors); + + for (auto i = 0; i < 11; ++i) { + std::atomic_int cnt{0}; + const auto errorMessage = "ConcatFilesSpillBatchStream::nextBatch fail"; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::ConcatFilesSpillBatchStream::nextBatch", + std::function([&](void* /*unused*/) { + if (cnt++ == i) { + VELOX_FAIL("ConcatFilesSpillBatchStream::nextBatch fail"); + } + })); + + VELOX_ASSERT_THROW(testSingleKeyWithSpill(vectors, "c0"), errorMessage); + } +} + DEBUG_ONLY_TEST_F(MergeTest, localMergeSmallBatch) { std::vector vectors; for (int32_t i = 0; i < 9; ++i) { @@ -735,7 +767,7 @@ DEBUG_ONLY_TEST_F(MergeTest, localMergeAbort) { })); SCOPED_TESTVALUE_SET( - "facebook::velox::exec::SpillMerger::asyncReadFromSpillFileStream", + "facebook::velox::exec::SpillMerger::readFromSpillFileStream", std::function([&](void* /*unused*/) { if (cnt++ == 2) { blocked = true; diff --git a/velox/exec/tests/MergerTest.cpp b/velox/exec/tests/MergerTest.cpp index e1845fc552e5..ab2bd8239961 100644 --- a/velox/exec/tests/MergerTest.cpp +++ b/velox/exec/tests/MergerTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" #include "velox/exec/Merge.h" #include "velox/exec/MergeSource.h" @@ -420,3 +421,37 @@ TEST_F(MergerTest, spillMerger) { checkResults(expectedResults, results); } } + +DEBUG_ONLY_TEST_F(MergerTest, spillMergerException) { + struct TestSetting { + size_t maxOutputRows; + size_t numSources; + size_t queueSize; + + std::string debugString() const { + return fmt::format( + "maxOutputRows:{}, numStreams:{}, queueSize:{}", + maxOutputRows, + numSources, + queueSize); + } + }; + + std::atomic_int cnt{0}; + const auto errorMessage = "ConcatFilesSpillBatchStream::nextBatch fail"; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::ConcatFilesSpillBatchStream::nextBatch", + std::function([&](void* /*unused*/) { + if (cnt++ == 11) { + VELOX_FAIL("ConcatFilesSpillBatchStream::nextBatch fail"); + } + })); + const auto numSources = 5; + const auto queueSize = 2; + const auto sources = createMergeSources(numSources, queueSize); + auto [inputs, filesGroup] = generateInputs(numSources, 16); + const auto spillMerger = + createSpillMerger(std::move(filesGroup), 100, queueSize); + spillMerger->start(); + VELOX_ASSERT_THROW(getOutputFromSpillMerger(spillMerger.get()), errorMessage); +} diff --git a/velox/exec/tests/MultiFragmentTest.cpp b/velox/exec/tests/MultiFragmentTest.cpp index f58cab54e27f..c72074c7dc5d 100644 --- a/velox/exec/tests/MultiFragmentTest.cpp +++ b/velox/exec/tests/MultiFragmentTest.cpp @@ -127,7 +127,9 @@ class MultiFragmentTest : public HiveConnectorTestBase, std::unordered_map& extraQueryConfigs, int destination = 0, Consumer consumer = nullptr, - int64_t maxMemory = memory::kMaxMemory) const { + int64_t maxMemory = memory::kMaxMemory, + const std::optional& diskSpillOpts = + std::nullopt) const { auto configCopy = configSettings_; for (const auto& [k, v] : extraQueryConfigs) { configCopy[k] = v; @@ -148,7 +150,9 @@ class MultiFragmentTest : public HiveConnectorTestBase, 
destination, std::move(queryCtx), Task::ExecutionMode::kParallel, - std::move(consumer)); + std::move(consumer), + /*memoryArbitrationPriority=*/0, + diskSpillOpts); } std::vector makeVectors(int count, int rowsPerVector) { @@ -917,11 +921,17 @@ TEST_P(MultiFragmentTest, mergeExchangeWithSpill) { .capturePlanNodeId(partitionNodeId) .planNode(); localMergeNodeIds.push_back(localMergeNodeId); - auto sortTask = - makeTask(sortTaskId, partialSortPlan, spillMergeConfigs, tasks.size()); spillDirectories.push_back(TempDirectoryPath::create()); - sortTask->setSpillDirectory( - spillDirectories[numPartialSortTasks]->getPath()); + common::SpillDiskOptions spillOpts; + spillOpts.spillDirPath = spillDirectories[numPartialSortTasks]->getPath(); + auto sortTask = makeTask( + sortTaskId, + partialSortPlan, + spillMergeConfigs, + tasks.size(), + /*consumer=*/nullptr, + memory::kMaxMemory, + spillOpts); tasks.push_back(sortTask); sortTask->start(4); @@ -2915,8 +2925,8 @@ TEST_P(MultiFragmentTest, mergeSmallBatchesInExchange) { } else { test(1, 1'000); test(1'000, 72); - test(10'000, 7); - test(100'000, 1); + test(10'000, 8); + test(100'000, 2); } } diff --git a/velox/exec/tests/OperatorUtilsTest.cpp b/velox/exec/tests/OperatorUtilsTest.cpp index 45da6804d670..e03172c71a5b 100644 --- a/velox/exec/tests/OperatorUtilsTest.cpp +++ b/velox/exec/tests/OperatorUtilsTest.cpp @@ -48,7 +48,8 @@ class OperatorUtilsTest : public OperatorTestBase { std::move(planFragment), 0, core::QueryCtx::create(executor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); driver_ = Driver::testingCreate(); driverCtx_ = std::make_unique(task_, 0, 0, 0, 0); driverCtx_->driver = driver_.get(); diff --git a/velox/exec/tests/OutputBufferManagerTest.cpp b/velox/exec/tests/OutputBufferManagerTest.cpp index f72957c9281e..e9abfb651403 100644 --- a/velox/exec/tests/OutputBufferManagerTest.cpp +++ b/velox/exec/tests/OutputBufferManagerTest.cpp @@ -109,7 +109,8 @@ class 
OutputBufferManagerTest : public testing::Test { std::move(planFragment), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); bufferManager_->initializeTask(task, kind, numDestinations, numDrivers); return task; diff --git a/velox/exec/tests/PlanBuilderTest.cpp b/velox/exec/tests/PlanBuilderTest.cpp index ca2a2c073f77..63dfd88b1e92 100644 --- a/velox/exec/tests/PlanBuilderTest.cpp +++ b/velox/exec/tests/PlanBuilderTest.cpp @@ -17,6 +17,7 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/core/Expressions.h" #include "velox/exec/WindowFunction.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/TestIndexStorageConnector.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" @@ -374,9 +375,10 @@ TEST_F(PlanBuilderTest, indexLookupJoinBuilder) { .rightKeys({"u0"}) .indexSource(rightScan) .joinConditions({"contains(t1, u1)"}) - .includeMatchColumn(false) + .hasMarker(false) .outputLayout({"t0", "u1"}) .joinType(core::JoinType::kInner) + .filter("t0 > 0") .endIndexLookupJoin() .planNode(); @@ -389,7 +391,61 @@ TEST_F(PlanBuilderTest, indexLookupJoinBuilder) { ASSERT_EQ(indexJoinNode->leftKeys()[0]->name(), "t0"); ASSERT_EQ(indexJoinNode->rightKeys()[0]->name(), "u0"); ASSERT_EQ(indexJoinNode->joinConditions().size(), 1); - ASSERT_FALSE(indexJoinNode->includeMatchColumn()); + ASSERT_FALSE(indexJoinNode->hasMarker()); + ASSERT_EQ(indexJoinNode->outputType()->names().size(), 2); + ASSERT_EQ(indexJoinNode->outputType()->names()[0], "t0"); + ASSERT_EQ(indexJoinNode->outputType()->names()[1], "u1"); + ASSERT_EQ(indexJoinNode->filter()->toString(), "gt(ROW[\"t0\"],0)"); +} + +TEST_F(PlanBuilderTest, insertTableHandleParameter) { + auto data = makeRowVector({makeFlatVector(10, folly::identity)}); + auto directory = "/some/test/directory"; + + // Lambda to create 
a plan with given insertableHandle and verify it + auto testInsertTableHandle = + [&](std::shared_ptr insertTableHandle) { + // Create a plan with insertTableHandle + auto planBuilder = PlanBuilder().values({data}).tableWrite( + directory, + {}, + 0, + {}, + {}, + dwio::common::FileFormat::DWRF, + {}, + PlanBuilder::kHiveDefaultConnectorId, + {}, + nullptr, + "", + common::CompressionKind_NONE, + nullptr, + false, + connector::CommitStrategy::kNoCommit, + insertTableHandle); + + // Verify the plan node has the correct insert Table Handle. + auto tableWriteNode = + std::dynamic_pointer_cast( + planBuilder.planNode()); + ASSERT_NE(tableWriteNode, nullptr); + ASSERT_EQ(tableWriteNode->insertTableHandle(), insertTableHandle); + }; + + auto rowType = ROW({"c0", "c1", "c2"}, {BIGINT(), INTEGER(), SMALLINT()}); + auto hiveHandle = HiveConnectorTestBase::makeHiveInsertTableHandle( + rowType->names(), + rowType->children(), + {rowType->names()[0]}, // partitionedBy + nullptr, // bucketProperty + HiveConnectorTestBase::makeLocationHandle( + "/path/to/test", + std::nullopt, + connector::hive::LocationHandle::TableType::kNew)); + + auto insertHandle = std::make_shared( + std::string(PlanBuilder::kHiveDefaultConnectorId), hiveHandle); + testInsertTableHandle(insertHandle); } } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PlanNodeStatsTest.cpp b/velox/exec/tests/PlanNodeStatsTest.cpp new file mode 100644 index 000000000000..2a59537eb430 --- /dev/null +++ b/velox/exec/tests/PlanNodeStatsTest.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/PlanNodeStats.h" +#include + +namespace facebook::velox::exec::test { + +TEST(PlanNodeStatsTest, exprStatsTotal) { + PlanNodeStats stats; + stats.expressionStats["foo"] = ExprStats{ + .timing = {.wallNanos = 1, .cpuNanos = 2}, + .numProcessedRows = 3, + .numProcessedVectors = 4}; + + PlanNodeStats total; + total += stats; + EXPECT_EQ(total.expressionStats["foo"], stats.expressionStats["foo"]); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerTimeTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerTimeTransformTest.cpp new file mode 100644 index 000000000000..0e72861ff20b --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerTimeTransformTest.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerTimeTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +// Test that TIME is recognized as an intermediate type that needs +// transformation +TEST_F(PrestoQueryRunnerTimeTransformTest, isIntermediateOnlyType) { + // Core test: TIME should be an intermediate type + ASSERT_TRUE(isIntermediateOnlyType(TIME())); + + // Complex types containing TIME should also be intermediate types + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(TIME()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARCHAR(), TIME()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(TIME(), VARCHAR()))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({TIME(), BIGINT()}))); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, roundTrip) { + // Test basic TIME values (no nulls, some nulls, all nulls) + std::vector> no_nulls{0, 3661000, 43200000, 86399999}; + test(makeNullableFlatVector(no_nulls, TIME())); + + std::vector> some_nulls{ + 0, 3661000, std::nullopt, 86399999}; + test(makeNullableFlatVector(some_nulls, TIME())); + + std::vector> all_nulls{ + std::nullopt, std::nullopt, std::nullopt}; + test(makeNullableFlatVector(all_nulls, TIME())); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, transformArray) { + auto input = makeNullableFlatVector( + std::vector>{ + 0, // 00:00:00.000 + 1000, // 00:00:01.000 + 3661000, // 01:01:01.000 + 43200000, // 12:00:00.000 (noon) + 86399999, // 23:59:59.999 + 3723456, // 01:02:03.456 + 45678901, // 12:41:18.901 + std::nullopt, + 72000000, // 20:00:00.000 + 36000000 // 10:00:00.000 + }, + TIME()); + testArray(input); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, transformMap) { + // keys can't be null for maps + auto keys = makeNullableFlatVector( + std::vector>{ + 0, // 00:00:00.000 + 3661000, // 01:01:01.000 + 43200000, // 12:00:00.000 + 86399999, // 23:59:59.999 + 36000000, // 
10:00:00.000 + 72000000, // 20:00:00.000 + 1800000, // 00:30:00.000 + 7200000, // 02:00:00.000 + 64800000, // 18:00:00.000 + 32400000 // 09:00:00.000 + }, + TIME()); + + auto values = makeNullableFlatVector( + {100, 200, std::nullopt, 400, 500, std::nullopt, 700, 800, 900, 1000}, + BIGINT()); + + testMap(keys, values); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, transformRow) { + auto input = makeNullableFlatVector( + std::vector>{ + 0, // 00:00:00.000 + 3661000, // 01:01:01.000 + 43200000, // 12:00:00.000 + 86399999, // 23:59:59.999 + std::nullopt, + 36000000, // 10:00:00.000 + 72000000, // 20:00:00.000 + 1800000, // 00:30:00.000 + 7200000, // 02:00:00.000 + 64800000 // 18:00:00.000 + }, + TIME()); + testRow({input}, {"time_col"}); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrintPlanWithStatsTest.cpp b/velox/exec/tests/PrintPlanWithStatsTest.cpp index 73928fce6a54..7c8d1c6ff74b 100644 --- a/velox/exec/tests/PrintPlanWithStatsTest.cpp +++ b/velox/exec/tests/PrintPlanWithStatsTest.cpp @@ -48,6 +48,12 @@ void compareOutputs( for (; std::getline(iss, line);) { lineCount++; std::vector potentialLines; + if (expectedLineIndex >= expectedRegex.size()) { + ASSERT_FALSE(true) << "Output has more lines than expected." + << "\n Source: " << testName + << "\n Line number: " << lineCount + << "\n Unexpected Line: " << line; + } auto expectedLine = expectedRegex.at(expectedLineIndex++); while (!RE2::FullMatch(line, expectedLine.line)) { potentialLines.push_back(expectedLine.line); @@ -59,11 +65,18 @@ void compareOutputs( << "\n Expected Line one of: " << folly::join(",", potentialLines); } + if (expectedLineIndex >= expectedRegex.size()) { + ASSERT_FALSE(true) + << "Output did not match and no more patterns to check." 
+ << "\n Source: " << testName << "\n Line number: " << lineCount + << "\n Line: " << line + << "\n Expected Line one of: " << folly::join(",", potentialLines); + } expectedLine = expectedRegex.at(expectedLineIndex++); } } for (int i = expectedLineIndex; i < expectedRegex.size(); i++) { - ASSERT_TRUE(expectedRegex[expectedLineIndex].optional); + ASSERT_TRUE(expectedRegex[i].optional); } } @@ -209,6 +222,7 @@ TEST_F(PrintPlanWithStatsTest, innerJoinWithTableScan) { true}, {" processedSplits[ ]+sum: 20, count: 1, min: 20, max: 20, avg: 20"}, {" processedStrides[ ]+sum: 20, count: 1, min: 20, max: 20, avg: 20"}, + {" processedUnits [ ]* sum: .+, count: .+, min: .+, max: .+"}, {" ramReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" readyPreloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, @@ -218,6 +232,7 @@ TEST_F(PrintPlanWithStatsTest, innerJoinWithTableScan) { {" storageReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" unitLoadNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}, {" waitForPreloadSplitNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}, {" -- Project\\[1\\]\\[expressions: \\(u_c0:INTEGER, ROW\\[\"c0\"\\]\\), \\(u_c1:BIGINT, ROW\\[\"c1\"\\]\\)\\] -> u_c0:INTEGER, u_c1:BIGINT"}, {" Output: 100 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: 0B, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, @@ -303,6 +318,7 @@ TEST_F(PrintPlanWithStatsTest, partialAggregateWithTableScan) { {" prefetchBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" processedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, {" processedStrides [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedUnits [ ]* sum: .+, count: .+, min: .+, max: .+"}, {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, {" ramReadBytes [ ]* sum: .+, 
count: 1, min: .+, max: .+"}, @@ -313,7 +329,8 @@ TEST_F(PrintPlanWithStatsTest, partialAggregateWithTableScan) { {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" storageReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}}); + {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" unitLoadNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}}); } } @@ -353,13 +370,14 @@ TEST_F(PrintPlanWithStatsTest, tableWriterWithTableScan) { {" dataSourceLazyCpuNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceLazyInputBytes\\s+sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceLazyWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" dwrfWriterCount\\s+sum: .+, count: 1, min: .+, max: .+"}, {" numWrittenFiles\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningWallNanos\\s+sum: .+, count: 1, min: .+, max: .+, avg: .+"}, {" stripeSize\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" writeIOWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" writeIOWallNanos\\s+sum: .+, count: 1, min: .+, max: .+, avg: .+"}, {R"( -- TableScan\[0\]\[table: hive_table\] -> c0:BIGINT, c1:INTEGER, c2:SMALLINT, c3:REAL, c4:DOUBLE, c5:VARCHAR)"}, {R"( Input: 100 rows \(.+\), Output: 100 rows \(.+\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+))"}, {" connectorSplitSize[ ]* sum: .+, count: .+, min: .+, max: .+"}, @@ -377,6 +395,7 @@ TEST_F(PrintPlanWithStatsTest, tableWriterWithTableScan) { {" prefetchBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" processedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, 
avg: 1"}, {" processedStrides [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedUnits [ ]* sum: .+, count: .+, min: .+, max: .+"}, {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, {" ramReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, @@ -387,7 +406,8 @@ TEST_F(PrintPlanWithStatsTest, tableWriterWithTableScan) { {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" storageReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}}); + {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" unitLoadNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}}); } TEST_F(PrintPlanWithStatsTest, taskAPI) { diff --git a/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp b/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp index 2dda1e2e8688..056f0e521292 100644 --- a/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp +++ b/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp @@ -81,6 +81,10 @@ class TestExchangeController { if (holdBufferBytes_ == 0) { return; } + if (holdBuffer_ != nullptr) { + return; + } + holdPool_ = pool; holdBuffer_ = holdPool_->allocate(holdBufferBytes_); } @@ -227,11 +231,6 @@ class FakeSourceOperator : public SourceOperator { private: void initialize() override { Operator::initialize(); - - if (operatorCtx_->driverCtx()->driverId != 0) { - return; - } - testController_->maybeHoldBuffer(pool()); } diff --git a/velox/exec/tests/SpatialIndexTest.cpp b/velox/exec/tests/SpatialIndexTest.cpp new file mode 100644 index 000000000000..6faedbe00392 --- /dev/null +++ b/velox/exec/tests/SpatialIndexTest.cpp @@ -0,0 +1,201 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/SpatialIndex.h" +#include +#include + +using namespace ::testing; +using namespace facebook::velox::exec; + +namespace facebook::velox::exec::test { + +class SpatialIndexTest : public virtual testing::Test { + protected: + SpatialIndex index_; + + void makeIndex(std::vector envelopes) { + index_ = SpatialIndex(std::move(envelopes)); + } + + Envelope indexBounds() const { + return index_.bounds(); + } + + void assertQuery( + double minX, + double minY, + double maxX, + double maxY, + std::vector expected) const { + std::vector actual = + index_.query(Envelope::from(minX, minY, maxX, maxY)); + std::sort(actual.begin(), actual.end()); + std::sort(expected.begin(), expected.end()); + ASSERT_EQ(actual, expected); + } +}; + +TEST_F(SpatialIndexTest, testEnvelope) { + Envelope empty = Envelope::empty(); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_FALSE(Envelope::intersects(empty, empty)); + + Envelope point = + Envelope{.minX = 0, .minY = 0, .maxX = 0, .maxY = 0, .rowIndex = -1}; + ASSERT_FALSE(point.isEmpty()); + ASSERT_FALSE(Envelope::intersects(empty, point)); + ASSERT_TRUE(Envelope::intersects(point, point)); +} + +TEST_F(SpatialIndexTest, testEmptyIndex) { + makeIndex(std::vector{}); + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, std::numeric_limits::infinity()); + ASSERT_EQ(bounds.minY, std::numeric_limits::infinity()); + ASSERT_EQ(bounds.maxX, -std::numeric_limits::infinity()); + ASSERT_EQ(bounds.maxY, -std::numeric_limits::infinity()); + ASSERT_EQ(bounds.rowIndex, -1); + + assertQuery(0, 0, 1, 1, {}); +} + 
+TEST_F(SpatialIndexTest, testPointProbe) { + makeIndex(std::vector{ + Envelope{.minX = 1, .minY = 0, .maxX = 1, .maxY = 0, .rowIndex = 6}, + Envelope{.minX = 0, .minY = 0, .maxX = 0, .maxY = 0, .rowIndex = 5}, + Envelope{.minX = 0, .minY = 0, .maxX = 1, .maxY = 1, .rowIndex = 4}, + Envelope{.minX = -1, .minY = -1, .maxX = 0, .maxY = 0, .rowIndex = 3}, + Envelope{.minX = -1, .minY = -1, .maxX = 1, .maxY = 1, .rowIndex = 2}, + Envelope{.minX = 0.5, .minY = 0.5, .maxX = 1, .maxY = 1, .rowIndex = 1}, + }); + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, -1); + ASSERT_EQ(bounds.minY, -1); + ASSERT_EQ(bounds.maxX, 1); + ASSERT_EQ(bounds.maxY, 1); + ASSERT_EQ(bounds.rowIndex, -1); + + assertQuery(0, 0, 0, 0, {2, 3, 4, 5}); + assertQuery(0, 1, 0, 1, {2, 4}); +} + +TEST_F(SpatialIndexTest, testFloatImprecision) { + // Since the index casts doubles to floats then nudges the result, + // we should make sure that the index gives the right results on + // cases where the double doesn't have an exact float representation. 
+ float float1 = 1.0f; + float float1Down = + std::nextafterf(float1, -std::numeric_limits::infinity()); + float float2 = 2.0f; + float float2Up = + std::nextafterf(float2, std::numeric_limits::infinity()); + + double baseMax = static_cast(float2); + double baseMaxUp = + std::nextafter(baseMax, std::numeric_limits::infinity()); + double baseMaxDown = + std::nextafter(baseMax, -std::numeric_limits::infinity()); + double baseMin = static_cast(float1); + double baseMinUp = + std::nextafter(baseMin, std::numeric_limits::infinity()); + double baseMinDown = + std::nextafter(baseMin, -std::numeric_limits::infinity()); + + makeIndex(std::vector{ + Envelope::from(baseMin, baseMin, baseMax, baseMax, 1), + Envelope::from(baseMinUp, baseMinUp, baseMaxUp, baseMaxUp, 2), + Envelope::from(baseMinDown, baseMinDown, baseMaxDown, baseMaxDown, 3), + }); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, float1Down); + ASSERT_EQ(bounds.minY, float1Down); + ASSERT_EQ(bounds.maxX, float2Up); + ASSERT_EQ(bounds.maxY, float2Up); + + assertQuery(2.1, 2.1, 2.1, 2.1, {}); + assertQuery(baseMin, baseMin, baseMin, baseMin, {1, 2, 3}); + assertQuery(baseMinDown, baseMinDown, baseMinDown, baseMinDown, {1, 2, 3}); + assertQuery(baseMinUp, baseMinUp, baseMinUp, baseMinUp, {1, 2, 3}); + assertQuery(baseMax, baseMax, baseMax, baseMax, {1, 2, 3}); + assertQuery(baseMaxDown, baseMaxDown, baseMaxDown, baseMaxDown, {1, 2, 3}); + assertQuery(baseMaxUp, baseMaxUp, baseMaxUp, baseMaxUp, {1, 2, 3}); +} + +TEST_F(SpatialIndexTest, testFloatImprecisionSubnormal) { + // Check that our bumping rules work for subnormal floats as well. 
+ float subnormalFloatDown = + std::nextafterf(0.0, -std::numeric_limits::infinity()); + float subnormalFloatUp = + std::nextafterf(0.0, std::numeric_limits::infinity()); + + double subnormalDoubleDown = + std::nextafter(0.0, -std::numeric_limits::infinity()); + double subnormalDoubleUp = + std::nextafter(0.0, std::numeric_limits::infinity()); + + makeIndex(std::vector{ + Envelope::from(0.0, 0.0, 0.0, 0.0, 1), + Envelope::from( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + 2), + Envelope::from( + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + 3), + Envelope::from( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleUp, + subnormalDoubleUp, + 4), + }); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, subnormalFloatDown); + ASSERT_EQ(bounds.minY, subnormalFloatDown); + ASSERT_EQ(bounds.maxX, subnormalFloatUp); + ASSERT_EQ(bounds.maxY, subnormalFloatUp); + + assertQuery(0.1, 0.1, 0.1, 0.1, {}); + assertQuery(0.0, 0.0, 0.0, 0.0, {1, 2, 3, 4}); + assertQuery( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + {1, 2, 3, 4}); + assertQuery( + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + {1, 2, 3, 4}); + assertQuery( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleUp, + subnormalDoubleUp, + {1, 2, 3, 4}); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/SpatialJoinTest.cpp b/velox/exec/tests/SpatialJoinTest.cpp index 539f132c6e6b..b1144ae6a872 100644 --- a/velox/exec/tests/SpatialJoinTest.cpp +++ b/velox/exec/tests/SpatialJoinTest.cpp @@ -16,6 +16,7 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/core/PlanFragment.h" #include "velox/core/PlanNode.h" +#include "velox/core/QueryConfig.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include 
"velox/exec/tests/utils/PlanBuilder.h" @@ -71,45 +72,25 @@ class SpatialJoinTest : public OperatorTestBase { core::JoinType joinType, const std::vector>& expectedLeftWkts, const std::vector>& expectedRightWkts) { - runTestWithDrivers( - probeWkts, - buildWkts, - predicate, - joinType, - expectedLeftWkts, - expectedRightWkts, - 1, - false); - runTestWithDrivers( - probeWkts, - buildWkts, - predicate, - joinType, - expectedLeftWkts, - expectedRightWkts, - 1, - true); - runTestWithDrivers( - probeWkts, - buildWkts, - predicate, - joinType, - expectedLeftWkts, - expectedRightWkts, - 4, - false); - runTestWithDrivers( - probeWkts, - buildWkts, - predicate, - joinType, - expectedLeftWkts, - expectedRightWkts, - 4, - true); + for (bool separateProbeBatches : {false, true}) { + for (size_t maxBatchSize : {1024, 1}) { + for (int32_t maxDrivers : {1, 4}) { + runTestWithConfig( + probeWkts, + buildWkts, + predicate, + joinType, + expectedLeftWkts, + expectedRightWkts, + maxDrivers, + maxBatchSize, + separateProbeBatches); + } + } + } } - void runTestWithDrivers( + void runTestWithConfig( const std::vector>& probeWkts, const std::vector>& buildWkts, const std::string& predicate, @@ -117,6 +98,7 @@ class SpatialJoinTest : public OperatorTestBase { const std::vector>& expectedLeftWkts, const std::vector>& expectedRightWkts, int32_t maxDrivers, + size_t maxBatchSize, bool separateBatches) { std::vector> probeWktsStr( probeWkts.begin(), probeWkts.end()); @@ -162,6 +144,9 @@ class SpatialJoinTest : public OperatorTestBase { .localPartition({}) .planNode(), predicate, + "left_g", + "right_g", + std::nullopt, {"left_g", "right_g"}, joinType) .project( @@ -169,10 +154,24 @@ class SpatialJoinTest : public OperatorTestBase { "ST_AsText(right_g) AS right_g"}) .planNode(); AssertQueryBuilder builder{plan}; - builder.maxDrivers(maxDrivers).assertResults({expectedRows}); + builder.maxDrivers(maxDrivers) + .config(core::QueryConfig::kPreferredOutputBatchRows, maxBatchSize) + 
.config(core::QueryConfig::kMaxOutputBatchRows, maxBatchSize) + .assertResults({expectedRows}); } }; -TEST_F(SpatialJoinTest, simpleSpatialJoin) { + +TEST_F(SpatialJoinTest, testTrivialSpatialJoin) { + runTest( + {"POINT (1 1)"}, + {"POINT (1 1)"}, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (1 1)"}, + {"POINT (1 1)"}); +} + +TEST_F(SpatialJoinTest, testSimpleSpatialInnerJoin) { runTest( {"POINT (1 1)", "POINT (1 2)"}, {"POINT (1 1)", "POINT (2 1)"}, @@ -180,6 +179,9 @@ TEST_F(SpatialJoinTest, simpleSpatialJoin) { core::JoinType::kInner, {"POINT (1 1)"}, {"POINT (1 1)"}); +} + +TEST_F(SpatialJoinTest, testSimpleSpatialLeftJoin) { runTest( {"POINT (1 1)", "POINT (1 2)"}, {"POINT (1 1)", "POINT (2 1)"}, @@ -189,7 +191,36 @@ TEST_F(SpatialJoinTest, simpleSpatialJoin) { {"POINT (1 1)", std::nullopt}); } -TEST_F(SpatialJoinTest, selfSpatialJoin) { +TEST_F(SpatialJoinTest, testSpatialJoinNullRows) { + runTest( + {"POINT (0 0)", std::nullopt, "POINT (1 1)", std::nullopt}, + {"POINT (0 0)", "POINT (1 1)", std::nullopt, std::nullopt}, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (0 0)", "POINT (1 1)"}, + {"POINT (0 0)", "POINT (1 1)"}); + runTest( + {"POINT (0 0)", std::nullopt, "POINT (2 2)", std::nullopt}, + {"POINT (0 0)", "POINT (1 1)", std::nullopt, std::nullopt}, + "ST_Intersects(left_g, right_g)", + core::JoinType::kLeft, + {"POINT (0 0)", "POINT (2 2)", std::nullopt, std::nullopt}, + {"POINT (0 0)", std::nullopt, std::nullopt, std::nullopt}); +} + +// Test geometries that don't intersect but their envelopes do. 
+// Important to test spatial index +TEST_F(SpatialJoinTest, simpleSpatialJoinEnvelopes) { + runTest( + {"POINT (0.5 0.6)", "POINT (0.5 0.5)", "LINESTRING (0 0.1, 0.9 1)"}, + {"POLYGON ((0 0, 1 1, 1 0, 0 0))"}, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (0.5 0.5)"}, + {"POLYGON ((0 0, 1 1, 1 0, 0 0))"}); +} + +TEST_F(SpatialJoinTest, testSelfSpatialJoin) { std::vector> inputWkts = { kPolygonA, kPolygonB, kPolygonC, kPolygonD}; std::vector> leftOutputWkts = { @@ -300,6 +331,16 @@ TEST_F(SpatialJoinTest, pointPolygonSpatialJoin) { polygonOutputWkts); } +TEST_F(SpatialJoinTest, simpleNullRowsJoin) { + runTest( + {"POINT (1 1)", std::nullopt, "POINT (1 2)"}, + {"POINT (1 1)", "POINT (2 1)", std::nullopt}, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (1 1)"}, + {"POINT (1 1)"}); +} + TEST_F(SpatialJoinTest, failOnGroupedExecution) { std::vector batches{ makeRowVector({"wkt"}, {makeFlatVector({"POINT(0 0)"})})}; @@ -317,6 +358,9 @@ TEST_F(SpatialJoinTest, failOnGroupedExecution) { .localPartition({}) .planNode(), "ST_Intersects(left_g, right_g)", + "left_g", + "right_g", + std::nullopt, {"left_g", "right_g"}, core::JoinType::kInner) .project( diff --git a/velox/exec/tests/TableEvolutionFuzzer.cpp b/velox/exec/tests/TableEvolutionFuzzer.cpp index a353cbe5057f..d92649050e09 100644 --- a/velox/exec/tests/TableEvolutionFuzzer.cpp +++ b/velox/exec/tests/TableEvolutionFuzzer.cpp @@ -70,6 +70,14 @@ VectorFuzzer::Options makeVectorFuzzerOptions() { return options; } +template +void removeFromVector(std::vector& vec, const T& value) { + auto it = std::find(vec.begin(), vec.end(), value); + if (it != vec.end()) { + vec.erase(it); + } +} + bool hasUnsupportedMapKey(const TypePtr& type) { switch (type->kind()) { case TypeKind::MAP: { @@ -219,6 +227,7 @@ std::vector> runTaskCursors( } std::vector> results; constexpr std::chrono::seconds kTaskTimeout(10); + results.reserve(futures.size()); for (auto& future : futures) { 
results.push_back(std::move(future).get(kTaskTimeout)); } @@ -559,12 +568,17 @@ void TableEvolutionFuzzer::run() { 2 * config_.evolutionCount - 1); RowVectorPtr finalExpectedData; + folly::F14FastMap> globalMapColumnKeys; + std::vector globallyConsistentColumnIndexVector; + createWriteTasks( testSetups, bucketColumnIndices, tableOutputRootDir->getPath(), writeTasks, - finalExpectedData); + finalExpectedData, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); auto executor = folly::getGlobalCPUExecutor(); auto writeResults = runTaskCursors(writeTasks, *executor); @@ -618,10 +632,20 @@ void TableEvolutionFuzzer::run() { } std::vector> scanTasks(2); - scanTasks[0] = - makeScanTask(rowType, std::move(actualSplits), pushownConfig, false); - scanTasks[1] = - makeScanTask(rowType, std::move(expectedSplits), pushownConfig, true); + scanTasks[0] = makeScanTask( + rowType, + std::move(actualSplits), + pushownConfig, + false, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); + scanTasks[1] = makeScanTask( + rowType, + std::move(expectedSplits), + pushownConfig, + true, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); ScopedOOMInjector oomInjectorReadPath( [this]() -> bool { return folly::Random::oneIn(10, rng_); }, @@ -799,7 +823,9 @@ std::unique_ptr TableEvolutionFuzzer::makeWriteTask( const std::string& outputDir, const std::vector& bucketColumnIndices, FuzzerGenerator& rng, - bool enableFlatMap) { + bool enableFlatMap, + folly::F14FastMap>& globalMapColumnKeys, + std::vector& globallyCompatibleFlatmapColumns) { auto builder = PlanBuilder().values({data}); // Create serdeParameters using proper dwrf::Config for flatmap configuration @@ -813,6 +839,7 @@ std::unique_ptr TableEvolutionFuzzer::makeWriteTask( if (setup.schema->childAt(i)->isMap()) { // Check if this specific map column has any empty elements if (hasEmptyElement(data, i)) { + removeFromVector(globallyCompatibleFlatmapColumns, i); continue; } @@ -822,7 +849,76 @@ 
std::unique_ptr TableEvolutionFuzzer::makeWriteTask( supportedMapColumnIndices.push_back(static_cast(i)); VLOG(1) << "Write column " << setup.schema->nameOf(i) << " as flatmap"; + + // Extract actual keys from the map data and collect directly into + // global set + SelectivityVector allRows(data->childAt(i)->size()); + DecodedVector decodedMap(*data->childAt(i), allRows); + auto* mapVector = decodedMap.base()->asChecked(); + if (mapVector->size() > 0) { + auto keys = mapVector->mapKeys(); + + if (keys) { + // Collect keys directly into the global set + auto& uniqueKeys = globalMapColumnKeys[static_cast(i)]; + + // Iterate through the decoded rows, not the raw mapVector + // indices + for (vector_size_t row = 0; row < data->childAt(i)->size(); + ++row) { + auto decodedIndex = decodedMap.index(row); + if (!decodedMap.isNullAt(row) && + !mapVector->isNullAt(decodedIndex)) { + // Get the map entry for this decoded row + auto mapOffset = mapVector->offsetAt(decodedIndex); + auto mapSize = mapVector->sizeAt(decodedIndex); + + // Process all keys in this map entry + for (vector_size_t keyIdx = 0; keyIdx < mapSize; ++keyIdx) { + auto keyPosition = mapOffset + keyIdx; + if (!keys->isNullAt(keyPosition)) { + std::string keyStr; + if (keys->type()->isVarchar() || + keys->type()->isVarbinary()) { + auto* keyVector = keys->asFlatVector(); + auto keyView = keyVector->valueAt(keyPosition); + keyStr = std::string(keyView); + } else if (keys->type()->isInteger()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else if (keys->type()->isBigint()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else if (keys->type()->isSmallint()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else if (keys->type()->isTinyint()) { + auto* keyVector = 
keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else { + // This should not be reached since + // hasUnsupportedMapKey filters out unsupported types + VELOX_UNREACHABLE( + "Unsupported map key type: {}", + keys->type()->toString()); + } + uniqueKeys.insert(keyStr); + } + } + } + } + } + } + } else { + // Remove this column from globallyCompatibleFlatmapColumns + removeFromVector(globallyCompatibleFlatmapColumns, i); } + } else { + removeFromVector(globallyCompatibleFlatmapColumns, i); } } } @@ -906,22 +1002,79 @@ VectorPtr TableEvolutionFuzzer::liftToPrimitiveType( std::vector({})); } +RowTypePtr TableEvolutionFuzzer::buildFlatmapAsStructSchema( + const RowTypePtr& tableSchema, + const folly::F14FastMap>& + globalMapColumnKeys, + const std::vector& globallyCompatibleFlatmapColumns) { + if (globallyCompatibleFlatmapColumns.empty()) { + return tableSchema; + } + + VLOG(1) << "Setting up struct reading for " + << globallyCompatibleFlatmapColumns.size() + << " flatmap columns with real keys"; + + auto names = tableSchema->names(); + auto types = tableSchema->children(); + + // Filter globalMapColumnKeys to only include globally compatible columns + std::unordered_map> filteredMapColumnKeys; + for (int mapColumnIndex : globallyCompatibleFlatmapColumns) { + if (globalMapColumnKeys.find(mapColumnIndex) != globalMapColumnKeys.end()) { + // Add 50% probability to include this column in filteredMapColumnKeys + if (folly::Random::oneIn(2, rng_)) { + filteredMapColumnKeys[mapColumnIndex] = + globalMapColumnKeys.at(mapColumnIndex); + } + } + } + + // Use the filteredMapColumnKeys for struct reading + for (const auto& [mapColumnIndex, keysSet] : filteredMapColumnKeys) { + // Convert map type to struct type for struct reading + auto finalMapType = types[mapColumnIndex]->asMap(); + auto finalValueType = finalMapType.valueType(); + // Convert F14FastSet to vector for ROW constructor + std::vector keys(keysSet.begin(), 
keysSet.end()); + // Construct struct schema with real keys from write time + final value + // type + std::vector finalStructFieldTypes(keys.size(), finalValueType); + auto finalStructSchema = ROW(keys, finalStructFieldTypes); + + // Replace the map type with struct type in the schema + types[mapColumnIndex] = finalStructSchema; + } + + // Build new schema using struct reading for flatmap columns + return ROW(names, types); +} + std::unique_ptr TableEvolutionFuzzer::makeScanTask( const RowTypePtr& tableSchema, std::vector splits, const PushdownConfig& pushdownConfig, - bool useFiltersAsNode) { + bool useFiltersAsNode, + const folly::F14FastMap>& + globalMapColumnKeys, + const std::vector& globallyCompatibleFlatmapColumns) { + // Build schema for flatmap as struct reading + RowTypePtr newSchemaUsingStructReadingFlatMap = buildFlatmapAsStructSchema( + tableSchema, globalMapColumnKeys, globallyCompatibleFlatmapColumns); + CursorParameters params; params.serialExecution = true; // TODO: Mix in filter and aggregate pushdowns. 
- params.planNode = PlanBuilder() - .filtersAsNode(useFiltersAsNode) - .tableScanWithPushDown( - tableSchema, - /*pushdownConfig=*/pushdownConfig, - tableSchema, - {}) - .planNode(); + params.planNode = + PlanBuilder() + .filtersAsNode(useFiltersAsNode) + .tableScanWithPushDown( + newSchemaUsingStructReadingFlatMap, // Use struct schema for + // flatmap reading + /*pushdownConfig=*/pushdownConfig, + tableSchema, // Original schema as dataColumns + {}) + .planNode(); auto cursor = TaskCursor::create(params); for (auto& split : splits) { cursor->task()->addSplit("0", std::move(split)); @@ -986,16 +1139,41 @@ void TableEvolutionFuzzer::createWriteTasks( const std::vector& bucketColumnIndices, const std::string& tableOutputRootDirPath, std::vector>& writeTasks, - RowVectorPtr& finalExpectedData) { + RowVectorPtr& finalExpectedData, + folly::F14FastMap>& globalMapColumnKeys, + std::vector& globallyConsistentColumnIndexVector) { + // Initialize globallyConsistentColumnIndexVector with all map column indices + // from the first schema, then filter out incompatible ones during processing + if (hasMapColumns(testSetups[0].schema)) { + for (int j = 0; j < testSetups[0].schema->size(); ++j) { + if (testSetups[0].schema->childAt(j)->isMap() && + !hasUnsupportedMapKey(testSetups[0].schema->childAt(j))) { + globallyConsistentColumnIndexVector.push_back(j); + } + } + } + + // Generate data and create write tasks in a single loop for (int i = 0; i < config_.evolutionCount; ++i) { + // Generate fresh data for each evolution step independently auto data = vectorFuzzer_.fuzzRow(testSetups[i].schema, kVectorSize, false); for (auto& child : data->children()) { BaseVector::flattenVector(child); } + auto actualDir = fmt::format("{}/actual_{}", tableOutputRootDirPath, i); VELOX_CHECK(std::filesystem::create_directory(actualDir)); + + // Pass globally consistent columns to restrict flatmap usage writeTasks[2 * i] = makeWriteTask( - testSetups[i], data, actualDir, bucketColumnIndices, rng_, 
true); + testSetups[i], + data, + actualDir, + bucketColumnIndices, + rng_, + true, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); if (i == config_.evolutionCount - 1) { finalExpectedData = std::move(data); @@ -1012,7 +1190,9 @@ void TableEvolutionFuzzer::createWriteTasks( expectedDir, bucketColumnIndices, rng_, - true); + true, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); } } diff --git a/velox/exec/tests/TableEvolutionFuzzer.h b/velox/exec/tests/TableEvolutionFuzzer.h index f00a25923a27..2e179a180d8d 100644 --- a/velox/exec/tests/TableEvolutionFuzzer.h +++ b/velox/exec/tests/TableEvolutionFuzzer.h @@ -105,7 +105,10 @@ class TableEvolutionFuzzer { const std::string& outputDir, const std::vector& bucketColumnIndices, FuzzerGenerator& rng, - bool enableFlatMap); + bool enableFlatMap, + folly::F14FastMap>& + globalMapColumnKeys, + std::vector& globallyCompatibleFlatmapColumns); template VectorPtr liftToPrimitiveType( @@ -118,7 +121,18 @@ class TableEvolutionFuzzer { const RowTypePtr& tableSchema, std::vector splits, const PushdownConfig& pushdownConfig, - bool useFiltersAsNode); + bool useFiltersAsNode, + const folly::F14FastMap>& + globalMapColumnKeys = {}, + const std::vector& globallyCompatibleFlatmapColumns = {}); + + /// Builds schema for flatmap as struct reading by converting selected map + /// columns to struct types. + RowTypePtr buildFlatmapAsStructSchema( + const RowTypePtr& tableSchema, + const folly::F14FastMap>& + globalMapColumnKeys, + const std::vector& globallyCompatibleFlatmapColumns); /// Randomly generates bucket column indices for partitioning data. 
/// Returns a vector of column indices that will be used for bucketing, @@ -134,7 +148,10 @@ class TableEvolutionFuzzer { const std::vector& bucketColumnIndices, const std::string& tableOutputRootDirPath, std::vector>& writeTasks, - RowVectorPtr& finalExpectedData); + RowVectorPtr& finalExpectedData, + folly::F14FastMap>& + globalMapColumnKeys, + std::vector& globallyConsistentColumnIndexVector); /// Creates scan splits from write results. /// Converts the output of write tasks into scan splits that can be used diff --git a/velox/exec/tests/TableScanTest.cpp b/velox/exec/tests/TableScanTest.cpp index 2bc53f8fcf66..e8ad602259f0 100644 --- a/velox/exec/tests/TableScanTest.cpp +++ b/velox/exec/tests/TableScanTest.cpp @@ -1870,6 +1870,15 @@ TEST_F(TableScanTest, partitionedTableDoubleKey) { testPartitionedTable(filePath->getPath(), DOUBLE(), "3.5"); } +TEST_F(TableScanTest, partitionedTableDecimalKey) { + auto rowType = ROW({"c0", "c1"}, {BIGINT(), DOUBLE()}); + auto vectors = makeVectors(10, 1'000, rowType); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vectors); + createDuckDbTable(vectors); + testPartitionedTable(filePath->getPath(), DECIMAL(20, 4), "3.5123"); +} + TEST_F(TableScanTest, partitionedTableDateKey) { auto rowType = ROW({"c0", "c1"}, {BIGINT(), DOUBLE()}); auto vectors = makeVectors(10, 1'000, rowType); @@ -6112,5 +6121,28 @@ TEST_F(TableScanTest, duplicateFieldProject) { .assertResults("SELECT id, id FROM tmp WHERE name = 'John'"); } +TEST_F(TableScanTest, parallelUnitLoader) { + auto vectors = makeVectors(10, 1'000); + auto filePath = TempFilePath::create(); + writeToFile( + filePath->getPath(), + vectors, + std::make_shared(), + []() { return std::make_unique(1000, 0); }); + createDuckDbTable(vectors); + auto plan = tableScanNode(); + auto task = + AssertQueryBuilder(plan) + .splits(makeHiveConnectorSplits({filePath})) + .connectorSessionProperty( + kHiveConnectorId, + 
connector::hive::HiveConfig::kParallelUnitLoadCountSession, + std::to_string(3)) + .assertTypeAndNumRows(rowType_, 10'000); + auto stats = getTableScanRuntimeStats(task); + // Verify that parallel unit loader is enabled. + ASSERT_GT(stats.count("waitForUnitReadyNanos"), 0); +} + } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/TableWriterTest.cpp b/velox/exec/tests/TableWriterTest.cpp index 2b91f729c37e..863193b71a45 100644 --- a/velox/exec/tests/TableWriterTest.cpp +++ b/velox/exec/tests/TableWriterTest.cpp @@ -1822,7 +1822,7 @@ TEST_P(AllTableWriterTest, tableWriteOutputCheck) { ASSERT_EQ(writeFileName, targetFileName); } else { const std::string kParquetSuffix = ".parquet"; - if (folly::StringPiece(targetFileName).endsWith(kParquetSuffix)) { + if (targetFileName.ends_with(kParquetSuffix)) { // Remove the .parquet suffix. auto trimmedFilename = targetFileName.substr( 0, targetFileName.size() - kParquetSuffix.size()); @@ -2025,7 +2025,7 @@ TEST_P(AllTableWriterTest, columnStatsDataTypes) { const auto distinctCountStatsVector = result->childAt(nextColumnStatsIndex++)->asFlatVector(); HashStringAllocator allocator{pool_.get()}; - DenseHll denseHll{ + DenseHll<> denseHll{ std::string(distinctCountStatsVector->valueAt(0)).c_str(), &allocator}; ASSERT_EQ(denseHll.cardinality(), 1000); const auto maxDataSizeStatsVector = diff --git a/velox/exec/tests/TaskTest.cpp b/velox/exec/tests/TaskTest.cpp index e230dfbbd47c..0bb3e6e84db6 100644 --- a/velox/exec/tests/TaskTest.cpp +++ b/velox/exec/tests/TaskTest.cpp @@ -524,7 +524,8 @@ class TaskTest : public HiveConnectorTestBase { plan, 0, core::QueryCtx::create(), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); for (const auto& [nodeId, paths] : filePaths) { for (const auto& path : paths) { @@ -573,7 +574,8 @@ TEST_F(TaskTest, toJson) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + 
Task::ExecutionMode::kParallel, + exec::Consumer{}); ASSERT_EQ( task->toString(), @@ -638,7 +640,8 @@ TEST_F(TaskTest, wrongPlanNodeForSplit) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); // Add split for the source node. task->addSplit("0", exec::Split(folly::copy(connectorSplit))); @@ -694,7 +697,8 @@ TEST_F(TaskTest, wrongPlanNodeForSplit) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); errorMessage = "Splits can be associated only with leaf plan nodes which require splits. Plan node ID 0 doesn't refer to such plan node."; VELOX_ASSERT_THROW( @@ -721,7 +725,8 @@ TEST_F(TaskTest, duplicatePlanNodeIds) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel), + Task::ExecutionMode::kParallel, + exec::Consumer{}), "Plan node IDs must be unique. Found duplicate ID: 0.") } @@ -1228,7 +1233,8 @@ TEST_F(TaskTest, serialExecutionExternalBlockable) { plan, 0, core::QueryCtx::create(), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); std::vector results; for (;;) { auto result = nonBlockingTask->next(&continueFuture); @@ -1254,7 +1260,8 @@ TEST_F(TaskTest, serialExecutionExternalBlockable) { plan, 0, core::QueryCtx::create(), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); // Before we block, we expect `next` to get data normally. 
results.push_back(blockingTask->next(&continueFuture)); EXPECT_TRUE(results.back() != nullptr); @@ -1291,16 +1298,17 @@ TEST_F(TaskTest, supportSerialExecutionMode) { .project({"c0 % 10"}) .partitionedOutput({}, 1, std::vector{"p0"}) .planFragment(); - auto task = Task::create( - "single.execution.task.0", - plan, - 0, - core::QueryCtx::create(), - Task::ExecutionMode::kSerial); - // PartitionedOutput does not support serial execution mode, therefore the // task doesn't support it either. - ASSERT_FALSE(task->supportSerialExecutionMode()); + VELOX_ASSERT_THROW( + Task::create( + "single.execution.task.0", + plan, + 0, + core::QueryCtx::create(), + Task::ExecutionMode::kSerial, + exec::Consumer{}), + ""); } TEST_F(TaskTest, updateBroadCastOutputBuffers) { @@ -1316,7 +1324,8 @@ TEST_F(TaskTest, updateBroadCastOutputBuffers) { plan, 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(1, 1); @@ -1334,7 +1343,8 @@ TEST_F(TaskTest, updateBroadCastOutputBuffers) { plan, 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(1, 1); @@ -1622,8 +1632,13 @@ DEBUG_ONLY_TEST_F(TaskTest, inconsistentExecutionMode) { auto plan = PlanBuilder().values({data, data, data}).project({"c0"}).planFragment(); auto queryCtx = core::QueryCtx::create(driverExecutor_.get()); - auto task = - Task::create("task.0", plan, 0, queryCtx, Task::ExecutionMode::kSerial); + auto task = Task::create( + "task.0", + plan, + 0, + queryCtx, + Task::ExecutionMode::kSerial, + exec::Consumer{}); task->next(); VELOX_ASSERT_THROW(task->start(4, 1), "Inconsistent task execution mode."); @@ -1957,17 +1972,24 @@ TEST_F(TaskTest, driverCreationMemoryAllocationCheck) { .planFragment(); for (bool singleThreadExecution : {false, true}) { SCOPED_TRACE(fmt::format("singleThreadExecution: ", singleThreadExecution)); - auto 
badTask = Task::create( - "driverCreationMemoryAllocationCheck", - plan, - 0, - core::QueryCtx::create( - singleThreadExecution ? nullptr : driverExecutor_.get()), - singleThreadExecution ? Task::ExecutionMode::kSerial - : Task::ExecutionMode::kParallel); if (singleThreadExecution) { - VELOX_ASSERT_THROW(badTask->next(), "Unexpected memory pool allocations"); + VELOX_ASSERT_THROW( + Task::create( + "driverCreationMemoryAllocationCheck", + plan, + 0, + core::QueryCtx::create(nullptr), + Task::ExecutionMode::kSerial, + exec::Consumer{}), + "Unexpected memory pool allocations"); } else { + auto badTask = Task::create( + "driverCreationMemoryAllocationCheck", + plan, + 0, + core::QueryCtx::create(driverExecutor_.get()), + Task::ExecutionMode::kParallel, + exec::Consumer{}); VELOX_ASSERT_THROW( badTask->start(1), "Unexpected memory pool allocations"); } @@ -1994,36 +2016,34 @@ TEST_F(TaskTest, spillDirectoryCallback) { {{core::QueryConfig::kSpillEnabled, "true"}, {core::QueryConfig::kAggregationSpillEnabled, "true"}}); params.maxDrivers = 1; - - auto cursor = TaskCursor::create(params); - - std::shared_ptr task = cursor->task(); - auto tmpRootDir = exec::test::TempDirectoryPath::create(); - auto tmpParentSpillDir = fmt::format( + auto spillRootDir = exec::test::TempDirectoryPath::create(); + auto spillParentDir = fmt::format( "{}{}/parent_spill/", tests::utils::FaultyFileSystem::scheme(), - tmpRootDir->getPath()); - auto tmpSpillDir = fmt::format( + spillRootDir->getPath()); + auto spillDir = fmt::format( "{}{}/parent_spill/spill/", tests::utils::FaultyFileSystem::scheme(), - tmpRootDir->getPath()); + spillRootDir->getPath()); - EXPECT_FALSE(task->hasCreateSpillDirectoryCb()); - - task->setCreateSpillDirectoryCb([tmpParentSpillDir, tmpSpillDir]() { - auto filesystem = filesystems::getFileSystem(tmpParentSpillDir, nullptr); + params.spillDirectory = spillDir; + params.spillDirectoryCallback = [spillParentDir, spillDir]() { + auto filesystem = 
filesystems::getFileSystem(spillParentDir, nullptr); filesystems::DirectoryOptions options; options.values.emplace( filesystems::DirectoryOptions::kMakeDirectoryConfig.toString(), "dummy.config=123"); - filesystem->mkdir(tmpParentSpillDir, options); - filesystem->mkdir(tmpSpillDir); - return tmpSpillDir; - }); + filesystem->mkdir(spillParentDir, options); + filesystem->mkdir(spillDir); + return spillDir; + }; + auto cursor = TaskCursor::create(params); + std::shared_ptr task = cursor->task(); EXPECT_TRUE(task->hasCreateSpillDirectoryCb()); + auto fs = std::dynamic_pointer_cast( - filesystems::getFileSystem(tmpParentSpillDir, nullptr)); + filesystems::getFileSystem(spillParentDir, nullptr)); fs->setFileSystemInjectionError( std::make_exception_ptr(std::runtime_error("test exception")), @@ -2039,7 +2059,7 @@ TEST_F(TaskTest, spillDirectoryCallback) { auto mkdirOp = static_cast(op); if (mkdirOp->path == - fmt::format("{}/parent_spill/", tmpRootDir->getPath())) { + fmt::format("{}/parent_spill/", spillRootDir->getPath())) { parentDirectoryCreated = true; auto it = mkdirOp->options.values.find( filesystems::DirectoryOptions::kMakeDirectoryConfig.toString()); @@ -2047,7 +2067,7 @@ TEST_F(TaskTest, spillDirectoryCallback) { EXPECT_EQ(it->second, "dummy.config=123"); } if (mkdirOp->path == - fmt::format("{}/parent_spill/spill/", tmpRootDir->getPath())) { + fmt::format("{}/parent_spill/spill/", spillRootDir->getPath())) { spillDirectoryCreated = true; } return; @@ -2092,13 +2112,13 @@ TEST_F(TaskTest, spillDirectoryLifecycleManagement) { {{core::QueryConfig::kSpillEnabled, "true"}, {core::QueryConfig::kAggregationSpillEnabled, "true"}}); params.maxDrivers = 1; + const auto rootTempDir = exec::test::TempDirectoryPath::create(); + const auto tmpDirectoryPath = + rootTempDir->getPath() + "/spillDirectoryLifecycleManagement"; + params.spillDirectory = tmpDirectoryPath; auto cursor = TaskCursor::create(params); std::shared_ptr task = cursor->task(); - auto rootTempDir = 
exec::test::TempDirectoryPath::create(); - auto tmpDirectoryPath = - rootTempDir->getPath() + "/spillDirectoryLifecycleManagement"; - task->setSpillDirectory(tmpDirectoryPath, false); TestScopedSpillInjection scopedSpillInjection(100); while (cursor->moveNext()) { @@ -2154,7 +2174,6 @@ TEST_F(TaskTest, spillDirNotCreated) { auto* task = cursor->task().get(); auto rootTempDir = exec::test::TempDirectoryPath::create(); auto tmpDirectoryPath = rootTempDir->getPath() + "/spillDirNotCreated"; - task->setSpillDirectory(tmpDirectoryPath, false); while (cursor->moveNext()) { } @@ -2200,7 +2219,8 @@ DEBUG_ONLY_TEST_F(TaskTest, resumeAfterTaskFinish) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); // Request pause and then unblock operators to proceed. @@ -2231,7 +2251,12 @@ DEBUG_ONLY_TEST_F(TaskTest, serialLongRunningOperatorInTaskReclaimerAbort) { auto queryCtx = core::QueryCtx::create(driverExecutor_.get()); auto blockingTask = Task::create( - "blocking.task.0", plan, 0, queryCtx, Task::ExecutionMode::kSerial); + "blocking.task.0", + plan, + 0, + queryCtx, + Task::ExecutionMode::kSerial, + exec::Consumer{}); // Before we block, we expect `next` to get data normally. 
EXPECT_NE(nullptr, blockingTask->next()); @@ -2306,7 +2331,12 @@ DEBUG_ONLY_TEST_F(TaskTest, longRunningOperatorInTaskReclaimerAbort) { auto queryCtx = core::QueryCtx::create(driverExecutor_.get()); auto blockingTask = Task::create( - "blocking.task.0", plan, 0, queryCtx, Task::ExecutionMode::kParallel); + "blocking.task.0", + plan, + 0, + queryCtx, + Task::ExecutionMode::kParallel, + exec::Consumer{}); blockingTask->start(4, 1); const std::string abortErrorMessage("Synthetic Exception"); @@ -2372,7 +2402,8 @@ DEBUG_ONLY_TEST_F(TaskTest, taskReclaimStats) { std::move(plan), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); const int numReclaims{10}; @@ -2446,7 +2477,8 @@ DEBUG_ONLY_TEST_F(TaskTest, taskPauseTime) { std::move(plan), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); // Wait for the task driver starts to run. @@ -2495,7 +2527,8 @@ TEST_F(TaskTest, updateStatsWhileCloseOffThreadDriver) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); std::this_thread::sleep_for(std::chrono::milliseconds(100)); task->testingVisitDrivers( @@ -2540,7 +2573,8 @@ DEBUG_ONLY_TEST_F(TaskTest, driverEnqueAfterFailedAndPausedTask) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); // Request pause. 
@@ -2660,7 +2694,8 @@ DEBUG_ONLY_TEST_F(TaskTest, taskCancellation) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); auto cancellationToken = task->getCancellationToken(); ASSERT_FALSE(cancellationToken.isCancellationRequested()); @@ -2735,6 +2770,9 @@ TEST_F(TaskTest, invalidPlanNodeForBarrier) { VELOX_ASSERT_THROW( task->requestBarrier(), "Name of the first node that doesn't support barriered execution:"); + while (auto next = task->next()) { + } + ASSERT_TRUE(task->isFinished()); } TEST_F(TaskTest, barrierAfterNoMoreSplits) { @@ -2769,6 +2807,9 @@ TEST_F(TaskTest, barrierAfterNoMoreSplits) { VELOX_ASSERT_THROW( task->requestBarrier(), "Can't start barrier on task which has already received no more splits"); + while (auto next = task->next()) { + } + ASSERT_TRUE(task->isFinished()); } TEST_F(TaskTest, invalidTaskModeForBarrier) { diff --git a/velox/exec/tests/ThreadDebugInfoTest.cpp b/velox/exec/tests/ThreadDebugInfoTest.cpp index 75bab5b37593..6bcb1801e145 100644 --- a/velox/exec/tests/ThreadDebugInfoTest.cpp +++ b/velox/exec/tests/ThreadDebugInfoTest.cpp @@ -63,6 +63,7 @@ template struct InduceSegFaultFunction { template void call(TResult& out, const TInput& in) { + LOG(ERROR) << "error"; int* nullpointer = nullptr; *nullpointer = 6; } @@ -117,9 +118,10 @@ DEBUG_ONLY_TEST_F(ThreadDebugInfoDeathTest, withinTheCallingThread) { #ifndef IS_BUILDING_WITH_SAN ASSERT_DEATH( - (task->next()), + task->next(), ".*Fatal signal handler. 
Query Id= TaskCursorQuery_0 Task Id= single.execution.task.0.*"); #endif + task->requestCancel(); } DEBUG_ONLY_TEST_F(ThreadDebugInfoDeathTest, noThreadContextSet) { diff --git a/velox/exec/tests/UnnestTest.cpp b/velox/exec/tests/UnnestTest.cpp index 53979988f057..afa8aba48705 100644 --- a/velox/exec/tests/UnnestTest.cpp +++ b/velox/exec/tests/UnnestTest.cpp @@ -285,7 +285,7 @@ TEST_P(UnnestTest, arrayWithOrdinality) { assertQuery(params, expectedInDict); } -TEST_P(UnnestTest, arrayWithEmptyUnnestValue) { +TEST_P(UnnestTest, arrayWithMarker) { const auto array = makeArrayVectorFromJson( {"[1, 2, null, 4]", "null", "[5, 6]", "[]", "[null]", "[7, 8, 9]"}); const auto input = makeRowVector( @@ -312,7 +312,7 @@ TEST_P(UnnestTest, arrayWithEmptyUnnestValue) { expected->childAt(1), makeNullableFlatVector({1, 2, 3, 4, 1, 2, 1, 1, 2, 3})}); - const auto expectedWithEmptyUnnestValue = makeRowVector( + const auto expectedWithMarker = makeRowVector( {makeNullableFlatVector( {1.1, 1.1, @@ -340,23 +340,23 @@ TEST_P(UnnestTest, arrayWithEmptyUnnestValue) { 8, 9}), makeNullableFlatVector( - {false, - false, - false, - false, + {true, + true, true, - false, - false, true, false, + true, + true, false, - false, - false})}); + true, + true, + true, + true})}); const auto expectedWithBoth = makeRowVector( - {expectedWithEmptyUnnestValue->childAt(0), - expectedWithEmptyUnnestValue->childAt(1), + {expectedWithMarker->childAt(0), + expectedWithMarker->childAt(1), makeNullableFlatVector({1, 2, 3, 4, 0, 1, 2, 0, 1, 1, 2, 3}), - expectedWithEmptyUnnestValue->childAt(2)}); + expectedWithMarker->childAt(2)}); struct { bool hasOrdinality; @@ -375,7 +375,7 @@ TEST_P(UnnestTest, arrayWithEmptyUnnestValue) { } testSettings[] = { {false, false, input, expected}, {true, false, input, expectedWithOrdinality}, - {false, true, input, expectedWithEmptyUnnestValue}, + {false, true, input, expectedWithMarker}, {true, true, input, expectedWithBoth}}; for (const auto& testData : testSettings) { @@ -385,13 
+385,13 @@ TEST_P(UnnestTest, arrayWithEmptyUnnestValue) { if (testData.hasOrdinality) { ordinalityName = "ordinal"; } - std::optional emptyUnnestValueName; + std::optional markerName; if (testData.hasEmptyUnnestValue) { - emptyUnnestValueName = "emptyUnnestValue"; + markerName = "emptyUnnestValue"; } auto op = PlanBuilder() .values({testData.input}) - .unnest({"c0"}, {"c1"}, ordinalityName, emptyUnnestValueName) + .unnest({"c0"}, {"c1"}, ordinalityName, markerName) .planNode(); auto params = makeCursorParameters(op); assertQuery(params, testData.expected); @@ -466,7 +466,7 @@ TEST_P(UnnestTest, mapWithOrdinality) { assertQuery(params, expectedInDict); } -TEST_P(UnnestTest, mapWithEmptyUnnestValue) { +TEST_P(UnnestTest, mapWithMarker) { const auto map = makeNullableMapVector( {{{{1, 1.1}, {2, std::nullopt}}}, common::testutil::optionalEmpty, @@ -489,7 +489,7 @@ TEST_P(UnnestTest, mapWithEmptyUnnestValue) { expected->childAt(2), makeNullableFlatVector({1, 2, 1, 2, 3, 1})}); - const auto expectedWithEmptyUnnestValue = makeRowVector( + const auto expectedWithMarker = makeRowVector( {makeNullableFlatVector({1, 1, 2, 3, 3, 3, 4, 5}), makeNullableFlatVector( {1, 2, std::nullopt, 3, 4, 5, std::nullopt, 6}), @@ -503,14 +503,14 @@ TEST_P(UnnestTest, mapWithEmptyUnnestValue) { std::nullopt, std::nullopt}), makeNullableFlatVector( - {false, false, true, false, false, false, true, false})}); + {true, true, false, true, true, true, false, true})}); const auto expectedWithBoth = makeRowVector( - {expectedWithEmptyUnnestValue->childAt(0), - expectedWithEmptyUnnestValue->childAt(1), - expectedWithEmptyUnnestValue->childAt(2), + {expectedWithMarker->childAt(0), + expectedWithMarker->childAt(1), + expectedWithMarker->childAt(2), makeNullableFlatVector({1, 2, 0, 1, 2, 3, 0, 1}), - expectedWithEmptyUnnestValue->childAt(3)}); + expectedWithMarker->childAt(3)}); struct { bool hasOrdinality; @@ -529,7 +529,7 @@ TEST_P(UnnestTest, mapWithEmptyUnnestValue) { } testSettings[] = { {false, 
false, input, expected}, {true, false, input, expectedWithOrdinality}, - {false, true, input, expectedWithEmptyUnnestValue}, + {false, true, input, expectedWithMarker}, {true, true, input, expectedWithBoth}}; for (const auto& testData : testSettings) { @@ -539,13 +539,13 @@ TEST_P(UnnestTest, mapWithEmptyUnnestValue) { if (testData.hasOrdinality) { ordinalityName = "ordinal"; } - std::optional emptyUnnestValueName; + std::optional markerName; if (testData.hasEmptyUnnestValue) { - emptyUnnestValueName = "emptyUnnestValue"; + markerName = "emptyUnnestValue"; } auto op = PlanBuilder() .values({testData.input}) - .unnest({"c0"}, {"c1"}, ordinalityName, emptyUnnestValueName) + .unnest({"c0"}, {"c1"}, ordinalityName, markerName) .planNode(); auto params = makeCursorParameters(op); assertQuery(params, testData.expected); diff --git a/velox/exec/tests/VectorHasherTest.cpp b/velox/exec/tests/VectorHasherTest.cpp index 2a3cb8ef9f78..ecebd812b460 100644 --- a/velox/exec/tests/VectorHasherTest.cpp +++ b/velox/exec/tests/VectorHasherTest.cpp @@ -720,6 +720,11 @@ TEST_F(VectorHasherTest, merge) { EXPECT_EQ(numDistinct - 1, ids.size()); } +TEST_F(VectorHasherTest, computeValueIdsHugeInt) { + testComputeValueIds(false); + testComputeValueIds(true); +} + TEST_F(VectorHasherTest, computeValueIdsBigint) { testComputeValueIds(false); testComputeValueIds(true); diff --git a/velox/exec/tests/VeloxIn10MinDemo.cpp b/velox/exec/tests/VeloxIn10MinDemo.cpp index 87d571a1788c..59436835511e 100644 --- a/velox/exec/tests/VeloxIn10MinDemo.cpp +++ b/velox/exec/tests/VeloxIn10MinDemo.cpp @@ -47,25 +47,17 @@ class VeloxIn10MinDemo : public VectorTestBase { // Register type resolver with DuckDB SQL parser. parse::registerTypeResolver(); - // Register the TPC-H Connector Factory. - connector::registerConnectorFactory( - std::make_shared()); - // Create and register a TPC-H connector. 
- auto tpchConnector = - connector::getConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName) - ->newConnector( - kTpchConnectorId, - std::make_shared( - std::unordered_map())); + connector::tpch::TpchConnectorFactory factory; + auto tpchConnector = factory.newConnector( + kTpchConnectorId, + std::make_shared( + std::unordered_map())); connector::registerConnector(tpchConnector); } ~VeloxIn10MinDemo() { connector::unregisterConnector(kTpchConnectorId); - connector::unregisterConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName); } /// Parse SQL expression into a typed expression tree using DuckDB SQL parser. diff --git a/velox/exec/tests/WindowTest.cpp b/velox/exec/tests/WindowTest.cpp index f1399a7af177..e0b0d0c53812 100644 --- a/velox/exec/tests/WindowTest.cpp +++ b/velox/exec/tests/WindowTest.cpp @@ -115,6 +115,112 @@ TEST_F(WindowTest, spill) { ASSERT_GT(stats.spilledPartitions, 0); } +TEST_F(WindowTest, spillBatchReadTinyPartitions) { + const vector_size_t size = 1'000; + const uint32_t minReadBatchRows = 100; + // Each tiny partition has 1 row. + const uint32_t partitionRows = 1; + auto data = makeRowVector( + {"d", "p", "s"}, + { + // Payload. + makeFlatVector(size, [](auto row) { return row; }), + // Partition key. + makeFlatVector( + size, [](auto row) { return row / partitionRows; }), + // Sorting key. 
+ makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s)"}) + .capturePlanNodeId(windowId) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + TestScopedSpillInjection scopedSpillInjection(100); + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config(core::QueryConfig::kSpillEnabled, "true") + .config(core::QueryConfig::kWindowSpillEnabled, "true") + .config( + core::QueryConfig::kWindowSpillMinReadBatchRows, minReadBatchRows) + .spillDirectory(spillDirectory->getPath()) + .assertResults( + "SELECT *, row_number() over (partition by p order by s) FROM tmp"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(windowId); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + ASSERT_EQ( + stats.operatorStats.at("Window") + ->customStats[Window::kWindowSpillReadNumBatches] + .sum, + size / minReadBatchRows); +} + +TEST_F(WindowTest, spillBatchReadHugePartitions) { + const vector_size_t size = 1'000; + const uint32_t minReadBatchRows = 100; + // Each huge partition has 200 rows, which is larger than minReadBatchRows. + const uint32_t partitionRows = 200; + auto data = makeRowVector( + {"d", "p", "s"}, + { + // Payload. + makeFlatVector(size, [](auto row) { return row; }), + // Partition key. + makeFlatVector( + size, [](auto row) { return row / partitionRows; }), + // Sorting key. 
+ makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s)"}) + .capturePlanNodeId(windowId) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + TestScopedSpillInjection scopedSpillInjection(100); + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config(core::QueryConfig::kSpillEnabled, "true") + .config(core::QueryConfig::kWindowSpillEnabled, "true") + .config( + core::QueryConfig::kWindowSpillMinReadBatchRows, minReadBatchRows) + .spillDirectory(spillDirectory->getPath()) + .assertResults( + "SELECT *, row_number() over (partition by p order by s) FROM tmp"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(windowId); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + ASSERT_EQ( + stats.operatorStats.at("Window") + ->customStats[Window::kWindowSpillReadNumBatches] + .sum, + size / partitionRows); +} + TEST_F(WindowTest, spillUnsupported) { const vector_size_t size = 1'000; auto data = makeRowVector( @@ -658,12 +764,14 @@ DEBUG_ONLY_TEST_F(WindowTest, reserveMemorySort) { velox::common::PrefixSortConfig prefixSortConfig = velox::common::PrefixSortConfig{ std::numeric_limits::max(), 130, 12}; + folly::Synchronized opStats; auto sortWindowBuild = std::make_unique( plan, pool_.get(), std::move(prefixSortConfig), spillEnabled ? 
&spillConfig : nullptr, &nonReclaimableSection_, + &opStats, &spillStats); TestScopedSpillInjection scopedSpillInjection(0); diff --git a/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp b/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp index 9b4e85eb931d..817d128f3bc8 100644 --- a/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp +++ b/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp @@ -22,8 +22,7 @@ namespace fecebook::velox::exec::test { using namespace facebook::velox::test; namespace { -std::vector appendMatchColumn( - const std::vector columns) { +std::vector appendMarker(const std::vector columns) { std::vector resultColumns; resultColumns.reserve(columns.size() + 1); for (const auto& column : columns) { @@ -259,7 +258,7 @@ PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + bool hasMarker, JoinType joinType, const std::vector& outputColumns, PlanNodeId& joinNodeId) { @@ -267,14 +266,15 @@ PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( VELOX_CHECK_LE(leftKeys.size(), keyType_->size()); return PlanBuilder(planNodeIdGenerator, pool_.get()) .values(probeVectors) - .indexLookupJoin( - leftKeys, - rightKeys, - indexScanNode, - joinConditions, - includeMatchColumn, - includeMatchColumn ? appendMatchColumn(outputColumns) : outputColumns, - joinType) + .startIndexLookupJoin() + .leftKeys(leftKeys) + .rightKeys(rightKeys) + .indexSource(indexScanNode) + .joinConditions(joinConditions) + .hasMarker(hasMarker) + .outputLayout(hasMarker ? 
appendMarker(outputColumns) : outputColumns) + .joinType(joinType) + .endIndexLookupJoin() .capturePlanNodeId(joinNodeId) .planNode(); } @@ -285,7 +285,8 @@ PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + const std::string& filter, + bool hasMarker, JoinType joinType, const std::vector& outputColumns) { VELOX_CHECK_EQ(leftKeys.size(), rightKeys.size()); @@ -295,14 +296,16 @@ PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( .outputType(probeType_) .endTableScan() .captureScanNodeId(probeScanNodeId_) - .indexLookupJoin( - leftKeys, - rightKeys, - indexScanNode, - joinConditions, - includeMatchColumn, - includeMatchColumn ? appendMatchColumn(outputColumns) : outputColumns, - joinType) + .startIndexLookupJoin() + .leftKeys(leftKeys) + .rightKeys(rightKeys) + .indexSource(indexScanNode) + .joinConditions(joinConditions) + .filter(filter) + .hasMarker(hasMarker) + .outputLayout(hasMarker ? appendMarker(outputColumns) : outputColumns) + .joinType(joinType) + .endIndexLookupJoin() .capturePlanNodeId(joinNodeId_) .planNode(); } diff --git a/velox/exec/tests/utils/IndexLookupJoinTestBase.h b/velox/exec/tests/utils/IndexLookupJoinTestBase.h index 561b6dfe9feb..bcd714d9ec2e 100644 --- a/velox/exec/tests/utils/IndexLookupJoinTestBase.h +++ b/velox/exec/tests/utils/IndexLookupJoinTestBase.h @@ -86,7 +86,7 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { /// @param probeVectors: the probe input vectors. /// @param leftKeys: the left join keys of index lookup join. /// @param rightKeys: the right join keys of index lookup join. - /// @param includeMatchColumn: whether the index join output includes a match + /// @param hasMarker: whether the index join output includes a match /// column at the end. /// @param joinType: the join type of index lookup join. /// @param outputColumns: the output column names of index lookup join. 
@@ -99,30 +99,32 @@ class IndexLookupJoinTestBase : public HiveConnectorTestBase { const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + bool hasMarker, core::JoinType joinType, const std::vector& outputColumns, core::PlanNodeId& joinNodeId); /// Makes lookup join plan with the following parameters: + /// @param planNodeIdGenerator: generator for creating unique plan node IDs. /// @param indexScanNode: the index table scan node. - /// @param probeVectors: the probe input vectors. /// @param leftKeys: the left join keys of index lookup join. /// @param rightKeys: the right join keys of index lookup join. - /// @param includeMatchColumn: whether the index join output includes a match + /// @param joinConditions: the join conditions for index lookup join that + /// can't be converted into simple equality join conditions. + /// @param filter: additional filter condition SQL string to apply on join + /// results. Can be empty string if no additional filter is needed. + /// @param hasMarker: whether the index join output includes a match /// column at the end. /// @param joinType: the join type of index lookup join. /// @param outputColumns: the output column names of index lookup join. - /// @param joinNodeId: returns the plan node id of the index lookup join - /// node. 
- /// @param probeScanNodeId: returns the plan node id of the probe table scan PlanNodePtr makeLookupPlan( const std::shared_ptr& planNodeIdGenerator, TableScanNodePtr indexScanNode, const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - bool includeMatchColumn, + const std::string& filter, + bool hasMarker, JoinType joinType, const std::vector& outputColumns); diff --git a/velox/exec/tests/utils/LocalExchangeSource.cpp b/velox/exec/tests/utils/LocalExchangeSource.cpp index d64dad293872..b988a22e1198 100644 --- a/velox/exec/tests/utils/LocalExchangeSource.cpp +++ b/velox/exec/tests/utils/LocalExchangeSource.cpp @@ -271,7 +271,7 @@ class LocalExchangeSource : public exec::ExchangeSource { } bool checkSetRequestPromise() { - VeloxPromise promise; + VeloxPromise promise{VeloxPromise::makeEmpty()}; { std::lock_guard l(queue_->mutex()); promise = std::move(promise_); diff --git a/velox/exec/tests/utils/MergeTestBase.h b/velox/exec/tests/utils/MergeTestBase.h index 54d4bdbac700..7af30741a739 100644 --- a/velox/exec/tests/utils/MergeTestBase.h +++ b/velox/exec/tests/utils/MergeTestBase.h @@ -15,8 +15,8 @@ */ #include "velox/common/base/Exceptions.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/common/time/Timer.h" -#include "velox/exec/TreeOfLosers.h" #include diff --git a/velox/exec/tests/utils/PlanBuilder.cpp b/velox/exec/tests/utils/PlanBuilder.cpp index d48914fd168d..6ca69d5056ac 100644 --- a/velox/exec/tests/utils/PlanBuilder.cpp +++ b/velox/exec/tests/utils/PlanBuilder.cpp @@ -718,7 +718,8 @@ PlanBuilder& PlanBuilder::tableWrite( const common::CompressionKind compressionKind, const RowTypePtr& schema, const bool ensureFiles, - const connector::CommitStrategy commitStrategy) { + const connector::CommitStrategy commitStrategy, + std::shared_ptr insertTableHandle) { return TableWriterBuilder(*this) .outputDirectoryPath(outputDirectoryPath) .outputFileName(outputFileName) @@ -735,6 +736,7 @@ PlanBuilder& 
PlanBuilder::tableWrite( .compressionKind(compressionKind) .ensureFiles(ensureFiles) .commitStrategy(commitStrategy) + .insertHandle(insertTableHandle) .endTableWriter(); } @@ -1741,20 +1743,35 @@ PlanBuilder& PlanBuilder::nestedLoopJoin( PlanBuilder& PlanBuilder::spatialJoin( const core::PlanNodePtr& right, const std::string& joinCondition, + const std::string& probeGeometry, + const std::string& buildGeometry, + const std::optional& radius, const std::vector& outputLayout, core::JoinType joinType) { VELOX_CHECK_NOT_NULL(planNode_, "SpatialJoin cannot be the source node"); - auto resultType = concat(planNode_->outputType(), right->outputType()); + auto probeType = planNode_->outputType(); + auto buildType = right->outputType(); + auto resultType = concat(probeType, buildType); auto outputType = extract(resultType, outputLayout); VELOX_CHECK(!joinCondition.empty(), "SpatialJoin condition cannot be empty"); core::TypedExprPtr joinConditionExpr = parseExpr(joinCondition, resultType, options_, pool_); + auto probeGeometryField = field(probeType, probeGeometry); + auto buildGeometryField = field(buildType, buildGeometry); + std::optional radiusField; + if (radius.has_value()) { + radiusField = field(buildType, radius.value()); + } + planNode_ = std::make_shared( nextPlanNodeId(), joinType, std::move(joinConditionExpr), + std::move(probeGeometryField), + std::move(buildGeometryField), + std::move(radiusField), std::move(planNode_), right, outputType); @@ -1965,12 +1982,13 @@ PlanBuilder& PlanBuilder::indexLookupJoin( const std::vector& rightKeys, const core::TableScanNodePtr& right, const std::vector& joinConditions, - bool includeMatchColumn, + const std::string& filter, + bool hasMarker, const std::vector& outputLayout, core::JoinType joinType) { VELOX_CHECK_NOT_NULL(planNode_, "indexLookupJoin cannot be the source node"); auto inputType = concat(planNode_->outputType(), right->outputType()); - if (includeMatchColumn) { + if (hasMarker) { auto names = 
inputType->names(); names.push_back(outputLayout.back()); auto types = inputType->children(); @@ -1988,13 +2006,20 @@ PlanBuilder& PlanBuilder::indexLookupJoin( parseIndexJoinCondition(joinCondition, inputType, pool_)); } + // Parse filter expression if provided + core::TypedExprPtr filterExpr; + if (!filter.empty()) { + filterExpr = parseExpr(filter, inputType, options_, pool_); + } + planNode_ = std::make_shared( nextPlanNodeId(), joinType, std::move(leftKeyFields), std::move(rightKeyFields), std::move(joinConditionPtrs), - includeMatchColumn, + filterExpr, + hasMarker, std::move(planNode_), right, std::move(outputType)); @@ -2006,7 +2031,7 @@ PlanBuilder& PlanBuilder::unnest( const std::vector& replicateColumns, const std::vector& unnestColumns, const std::optional& ordinalColumn, - const std::optional& emptyUnnestValueName) { + const std::optional& markerName) { VELOX_CHECK_NOT_NULL(planNode_, "Unnest cannot be the source node"); std::vector> replicateFields; @@ -2042,7 +2067,7 @@ PlanBuilder& PlanBuilder::unnest( unnestFields, unnestNames, ordinalColumn, - emptyUnnestValueName, + markerName, planNode_); VELOX_CHECK(planNode_->supportsBarrier()); return *this; @@ -2508,7 +2533,7 @@ core::PlanNodePtr PlanBuilder::IndexLookupJoinBuilder::build( planBuilder_.planNode_, "IndexLookupJoin cannot be the source node"); auto inputType = concat(planBuilder_.planNode_->outputType(), indexSource_->outputType()); - if (includeMatchColumn_) { + if (hasMarker_) { auto names = inputType->names(); names.push_back(outputLayout_.back()); auto types = inputType->children(); @@ -2528,13 +2553,21 @@ core::PlanNodePtr PlanBuilder::IndexLookupJoinBuilder::build( joinCondition, inputType, planBuilder_.pool_)); } + // Parse filter expression if provided + core::TypedExprPtr filterExpr; + if (!filter_.empty()) { + filterExpr = parseExpr( + filter_, inputType, planBuilder_.options_, planBuilder_.pool_); + } + return std::make_shared( id, joinType_, std::move(leftKeyFields), 
std::move(rightKeyFields), std::move(joinConditionPtrs), - includeMatchColumn_, + filterExpr, + hasMarker_, std::move(planBuilder_.planNode_), indexSource_, std::move(outputType)); diff --git a/velox/exec/tests/utils/PlanBuilder.h b/velox/exec/tests/utils/PlanBuilder.h index 96b65390a860..5e0c4740b50c 100644 --- a/velox/exec/tests/utils/PlanBuilder.h +++ b/velox/exec/tests/utils/PlanBuilder.h @@ -395,8 +395,8 @@ class PlanBuilder { return *this; } - IndexLookupJoinBuilder& includeMatchColumn(bool includeMatchColumn) { - includeMatchColumn_ = includeMatchColumn; + IndexLookupJoinBuilder& hasMarker(bool hasMarker) { + hasMarker_ = hasMarker; return *this; } @@ -406,6 +406,13 @@ class PlanBuilder { return *this; } + /// @param filter SQL expression for the additional join filter. Can + /// use columns from both probe and build sides of the join. + IndexLookupJoinBuilder& filter(std::string filter) { + filter_ = std::move(filter); + return *this; + } + /// @param joinType Type of the join supported: inner, left. IndexLookupJoinBuilder& joinType(core::JoinType joinType) { joinType_ = joinType; @@ -427,7 +434,8 @@ class PlanBuilder { std::vector rightKeys_; core::TableScanNodePtr indexSource_; std::vector joinConditions_; - bool includeMatchColumn_{false}; + std::string filter_; + bool hasMarker_{false}; std::vector outputLayout_; core::JoinType joinType_{core::JoinType::kInner}; }; @@ -817,6 +825,12 @@ class PlanBuilder { /// output of the previous operator. /// @param ensureFiles When this option is set the HiveDataSink will always /// create a file even if there is no data. + /// @param commitStrategy The commit strategy to use for the table write + /// operation, default is kNoCommit. + /// @param insertTableHandle Encapsulates information needed to write data + /// to a table through a connector. If not specified, tableWrite will build + /// a HiveInsertTableHandle with columnHandles, bucketProperty and + /// locationHandle. 
PlanBuilder& tableWrite( const std::string& outputDirectoryPath, const std::vector& partitionBy, @@ -835,7 +849,8 @@ class PlanBuilder { const RowTypePtr& schema = nullptr, const bool ensureFiles = false, const connector::CommitStrategy commitStrategy = - connector::CommitStrategy::kNoCommit); + connector::CommitStrategy::kNoCommit, + std::shared_ptr insertTableHandle = nullptr); /// Add a TableWriteMergeNode. PlanBuilder& tableWriteMerge(); @@ -1302,6 +1317,9 @@ class PlanBuilder { PlanBuilder& spatialJoin( const core::PlanNodePtr& right, const std::string& joinCondition, + const std::string& probeGeometry, + const std::string& buildGeometry, + const std::optional& radius, const std::vector& outputLayout, core::JoinType joinType = core::JoinType::kInner); @@ -1315,6 +1333,11 @@ class PlanBuilder { /// node. Second input is specified in 'right' parameter and must be a /// table source with the connector table handle with index lookup support. /// + /// @param leftKeys Join keys from the probe side, the preceding plan node. + /// Cannot be empty. + /// @param rightKeys Join keys from the index lookup side, the plan node + /// specified in 'right' parameter. The number and types of left and right + /// keys must be the same. /// @param right The right input source with index lookup support. /// @param joinConditions SQL expressions as the join conditions. Each join /// condition must use columns from both sides. For the right side, it can @@ -1327,18 +1350,23 @@ class PlanBuilder { /// where "a" is the index column from right side and "b", "c" are either /// condition column from left side or a constant but at least one of them /// must not be constant. They all have the same type. - /// @param joinType Type of the join supported: inner, left. - /// @param includeMatchColumn if true, 'outputLayout' should include a boolean + /// @param filter SQL expression for the additional join filter to apply on + /// join results. 
This supports filters that can't be converted into join + /// conditions or lookup conditions. Can be an empty string if no additional + /// filter is needed. + /// @param hasMarker if true, 'outputLayout' should include a boolean /// column at the end to indicate if a join output row has a match or not. /// This only applies for left join. - /// - /// See hashJoin method for the description of the other parameters. + /// @param outputLayout Output layout consisting of columns from probe and + /// build sides. + /// @param joinType Type of the join supported: inner, left. PlanBuilder& indexLookupJoin( const std::vector& leftKeys, const std::vector& rightKeys, const core::TableScanNodePtr& right, const std::vector& joinConditions, - bool includeMatchColumn, + const std::string& filter, + bool hasMarker, const std::vector& outputLayout, core::JoinType joinType = core::JoinType::kInner); @@ -1360,16 +1388,16 @@ class PlanBuilder { /// @param ordinalColumn An optional name for the 'ordinal' column to produce. /// This column contains the index of the element of the unnested array or /// map. If not specified, the output will not contain this column. - /// @param emptyUnnestValueName An optional name for the - /// 'emptyUnnestValue' column to produce. This column contains a boolean - /// indicating if the output row has empty unnest value or not. If not - /// specified, the output will not contain this column and the unnest operator - /// also skips producing output rows with empty unnest value. + /// @param markerName An optional name for the marker column to produce. + /// This column contains a boolean indicating whether the output row has + /// non-empty unnested value. If not specified, the output will not contain + /// this column and the unnest operator also skips producing output rows + /// with empty unnest value. 
PlanBuilder& unnest( const std::vector& replicateColumns, const std::vector& unnestColumns, const std::optional& ordinalColumn = std::nullopt, - const std::optional& emptyUnnestValueName = std::nullopt); + const std::optional& markerName = std::nullopt); /// Add a WindowNode to compute one or more windowFunctions. /// @param windowFunctions A list of one or more window function SQL like diff --git a/velox/exec/tests/utils/QueryAssertions.cpp b/velox/exec/tests/utils/QueryAssertions.cpp index 58adaf6f9974..61842c1e5172 100644 --- a/velox/exec/tests/utils/QueryAssertions.cpp +++ b/velox/exec/tests/utils/QueryAssertions.cpp @@ -1441,6 +1441,15 @@ void waitForAllTasksToBeDeleted(uint64_t maxWaitUs) { folly::join("\n", pendingTaskStats)); } +void cancelAllTasks() { + std::vector> pendingTasks = Task::getRunningTasks(); + for (const auto& task : pendingTasks) { + if (task->isRunning()) { + task->requestCancel(); + } + } +} + std::shared_ptr assertQuery( const core::PlanNodePtr& plan, std::function addSplits, diff --git a/velox/exec/tests/utils/QueryAssertions.h b/velox/exec/tests/utils/QueryAssertions.h index 3acfa88885d0..182969cd3ea0 100644 --- a/velox/exec/tests/utils/QueryAssertions.h +++ b/velox/exec/tests/utils/QueryAssertions.h @@ -221,6 +221,11 @@ bool waitForTaskStateChange( /// during this wait call. This is for testing purpose for now. void waitForAllTasksToBeDeleted(uint64_t maxWaitUs = 3'000'000); +/// Cancels all currently running tasks across all available task managers. +/// This is primarily used in testing scenarios to clean up active tasks +/// and ensure test isolation between test cases. 
+void cancelAllTasks(); + std::shared_ptr assertQuery( const core::PlanNodePtr& plan, const std::string& duckDbSql, diff --git a/velox/exec/tests/utils/TableScanTestBase.cpp b/velox/exec/tests/utils/TableScanTestBase.cpp index 394020fe1b5a..71392147e4a5 100644 --- a/velox/exec/tests/utils/TableScanTestBase.cpp +++ b/velox/exec/tests/utils/TableScanTestBase.cpp @@ -173,6 +173,11 @@ void TableScanTestBase::testPartitionedTableImpl( std::string partitionValueStr; partitionValueStr = partitionValue.has_value() ? "'" + *partitionValue + "'" : "null"; + if (partitionValue.has_value() && partitionType->isDecimal()) { + auto [p, s] = getDecimalPrecisionScale(*partitionType); + partitionValueStr = + fmt::format("CAST({} AS DECIMAL({}, {}))", partitionValueStr, p, s); + } assertQuery( op, split, fmt::format("SELECT {}, * FROM tmp", partitionValueStr)); diff --git a/velox/experimental/cudf/CudfConfig.h b/velox/experimental/cudf/CudfConfig.h new file mode 100644 index 000000000000..ebe405e0018d --- /dev/null +++ b/velox/experimental/cudf/CudfConfig.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace facebook::velox::cudf_velox { + +struct CudfConfig { + /// Keys used by the initialize() method. 
+ static constexpr const char* kCudfEnabled{"cudf.enabled"}; + static constexpr const char* kCudfDebugEnabled{"cudf.debug_enabled"}; + static constexpr const char* kCudfMemoryResource{"cudf.memory_resource"}; + static constexpr const char* kCudfMemoryPercent{"cudf.memory_percent"}; + static constexpr const char* kCudfFunctionNamePrefix{ + "cudf.function_name_prefix"}; + static constexpr const char* kCudfForceReplace{"cudf.force_replace"}; + + /// Singleton CudfConfig instance. + /// Clients must set the configs below before invoking registerCudf(). + static CudfConfig& getInstance(); + + /// Initialize from a map with the above keys. + void initialize(std::unordered_map&&); + + /// Enable cudf by default. + /// Clients can disable here and enable it via the QueryConfig as well. + bool enabled{true}; + + /// Enable debug printing. + bool debugEnabled{false}; + + /// Memory resource for cuDF. + /// Possible values are (cuda, pool, async, arena, managed, managed_pool). + std::string memoryResource{"async"}; + + /// The initial percent of GPU memory to allocate for pool or arena memory + /// resources. + int32_t memoryPercent{50}; + + /// Register all the functions with the functionNamePrefix. + std::string functionNamePrefix; + + /// Force replacement of operators. Throws an error if a replacement fails. + bool forceReplace{false}; +}; + +} // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/README.md b/velox/experimental/cudf/README.md new file mode 100644 index 000000000000..5a4d68f25031 --- /dev/null +++ b/velox/experimental/cudf/README.md @@ -0,0 +1,66 @@ +# Velox-cuDF + +Velox-cuDF is a Velox extension module that uses the cuDF library to implement a GPU-accelerated backend for executing Velox plans. [cuDF](https://github.com/rapidsai/cudf) is an open source library for GPU data processing, and Velox-cuDF integrates with "[libcudf](https://github.com/rapidsai/cudf/tree/main/cpp)", the CUDA C++ core of cuDF. 
libcudf uses [Arrow](https://arrow.apache.org)-compatible data layouts and includes single-node, single-GPU algorithms for data processing. + +## How Velox and cuDF work together + +Velox-cuDF implements the Velox [DriverAdapter](https://github.com/facebookincubator/velox/blob/d9f953cd23880f29593534f1ba9031c6cea8ba06/velox/exec/Driver.h#L695) interface as [CudfDriverAdapter](https://github.com/facebookincubator/velox/blob/226b92cefedce4b8a484bfc351260edbd3d2e501/velox/experimental/cudf/exec/ToCudf.cpp#L301) to rewrite query plans for GPU execution. Generally the cuDF DriverAdapter replaces operators one-to-one. For end-to-end GPU execution where cuDF replaces all of the Velox CPU operators, cuDF relies on Velox's [pipeline-based execution model](https://facebookincubator.github.io/velox/develop/task.html) to separate stages of execution, partition the work across drivers, and schedule concurrent work on the GPU. + +For more information please refer to our blog: "[Extending Velox - GPU Acceleration with cuDF](https://velox-lib.io/blog/extending-velox-with-cudf)." + +## Getting started with Velox-cuDF + +cuDF supports Linux and WSL2 but not Windows or MacOS. cuDF also has minimum CUDA version, NVIDIA driver and GPU architecture requirements which can be found in the [RAPIDS Installation Guide](https://docs.rapids.ai/install/). Please refer to cuDF's [readme](https://github.com/rapidsai/cudf) and [developer guide](https://github.com/rapidsai/cudf/blob/main/cpp/doxygen/developer_guide/DEVELOPER_GUIDE.md) for more information. + +### Building Velox with cuDF + +The cuDF backend is included in Velox builds when the [VELOX_ENABLE_CUDF](https://github.com/facebookincubator/velox/blob/43df50c4f24bcbfa96f5739c072ab0894d41cf4c/CMakeLists.txt#L455) CMake option is set. 
The `adapters-cuda` service in Velox's [docker-compose.yml](https://github.com/facebookincubator/velox/blob/43df50c4f24bcbfa96f5739c072ab0894d41cf4c/docker-compose.yml#L69) is an excellent starting point for Velox builds with cuDF. + +1. Use `docker compose` to run an `adapters-cuda` image. +``` +$ docker compose -f docker-compose.yml run -e NUM_THREADS=8 --rm -v "$(pwd):/velox" adapters-cuda /bin/bash +``` +2. Once inside the image, build cuDF with the following flags: +``` +$ CUDA_ARCHITECTURES="native" EXTRA_CMAKE_FLAGS="-DVELOX_ENABLE_ARROW=ON -DVELOX_ENABLE_PARQUET=ON -DVELOX_ENABLE_BENCHMARKS=ON -DVELOX_ENABLE_BENCHMARKS_BASIC=ON" make cudf +``` +3. After cuDF is built, verify the build by running the unit tests. +``` +$ cd _build/release +$ ctest -R cudf -V +``` + +Velox-cuDF builds are included in Velox CI as part of the [adapters build](https://github.com/facebookincubator/velox/blob/de31a3eb07b5ec3cbd1e6320a989fcb2ee1a95a7/.github/workflows/linux-build-base.yml#L85). The build step for cuDF does not require the worker to have a GPU, so adding a Velox-cuDF build step to Velox CI is compatible with the existing runners. + +### Testing Velox with cuDF + +Tests with Velox-cuDF can only be run on GPU-enabled hardware. The Velox-cuDF tests in [experimental/cudf/tests](https://github.com/facebookincubator/velox/blob/main/velox/experimental/cudf/tests) include several types of tests: +* operator tests +* function tests +* fuzz tests (not yet implemented) + +The repo [rapidsai/velox-testing](https://github.com/rapidsai/velox-testing/) includes standard scripts for testing Velox-cuDF. Please refer to the [test_velox.sh](https://github.com/rapidsai/velox-testing/blob/main/velox/scripts/test_velox.sh) for running the Velox-cuDF unit tests. We plan to first develop GitHub Actions for GPU CI in [rapidsai/velox-testing](https://github.com/rapidsai/velox-testing/), and then later transition GPU-enabled GitHub Actions to Velox mainline. 
+ +#### Operator tests + +Many of the tests for cuDF are "operator tests" which confirm correct execution of simple query plans. cuDF's operator tests use `CudfDriverAdapter` to modify the test plan with GPU operators before executing it. The operator tests for cuDF include both tests that assert successful GPU operator replacement, and tests that pass with CPU fallback. + +#### Function tests + +Velox-cuDF also includes "function tests" which cover the behavior of shared functions that could be called in multiple operators. Velox-cuDF function tests assess the correctness of functions using one or more cuDF API calls to provide the output. [SubfieldFilterAstTest](https://github.com/facebookincubator/velox/blob/99a04b94eed42d1c35ae99101da3bf77b31652e8/velox/experimental/cudf/tests/SubfieldFilterAstTest.cpp#L158) includes several examples of function tests. Please note that unit tests for cuDF APIs are included in [cudf/cpp/tests](https://github.com/rapidsai/cudf/tree/branch-25.10/cpp/tests) rather than Velox. + +#### Fuzz tests + +Velox includes components for "fuzz testing" to ensure robustness of Velox operators. For instance, the [Join Fuzzer](https://github.com/facebookincubator/velox/blob/99a04b94eed42d1c35ae99101da3bf77b31652e8/velox/docs/develop/testing/join-fuzzer.rst) executes a random join type with random inputs and compares the Velox results with a reference query engine. Fuzz testing tools have been used for cuDF operator development, but fuzz testing for cuDF is not yet integrated into Velox mainline. + +### Benchmarking Velox with cuDF + +Velox's [TpchBenchmark](https://github.com/facebookincubator/velox/blob/d9f953cd23880f29593534f1ba9031c6cea8ba06/velox/benchmarks/tpch/TpchBenchmark.cpp) is derived from [TPC-H](https://www.tpc.org/tpch/) and provides a convenient tool for benchmarking Velox's performance with OLAP (Online Analytical Processing) workloads. 
Velox-cuDF includes GPU operators for the hand-built query plans located in [TpchQueryBuilder](https://github.com/facebookincubator/velox/blob/43df50c4f24bcbfa96f5739c072ab0894d41cf4c/velox/exec/tests/utils/TpchQueryBuilder.cpp). Velox [PR 13695](https://github.com/facebookincubator/velox/pull/13695) extends Velox's TpchBenchmark to the cuDF backend. + +Please note that Velox's hand-built query plans require the data set to have floating-point types in place of the fixed-point types defined in the standard. Further development of Velox's TpchBenchmark could allow correct behavior with both fixed-point and floating-point types. + +## Contributing + +Velox-cuDF's development priorities are documented as Velox issues using the "[cuDF]" prefix. Please check out the [open issues](https://github.com/facebookincubator/velox/issues?q=is%3Aissue%20state%3Aopen%20%5BcuDF%5D) to learn more. + +We would love to hear from you in Velox's Slack workspace, please see Velox discussion [11348](https://github.com/facebookincubator/velox/discussions/11348) for information on joining. diff --git a/velox/experimental/cudf/connectors/CMakeLists.txt b/velox/experimental/cudf/connectors/CMakeLists.txt index 37a9408221cc..d73b2c3e1286 100644 --- a/velox/experimental/cudf/connectors/CMakeLists.txt +++ b/velox/experimental/cudf/connectors/CMakeLists.txt @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-add_subdirectory(parquet) +add_subdirectory(hive) diff --git a/velox/experimental/cudf/connectors/parquet/CMakeLists.txt b/velox/experimental/cudf/connectors/hive/CMakeLists.txt similarity index 50% rename from velox/experimental/cudf/connectors/parquet/CMakeLists.txt rename to velox/experimental/cudf/connectors/hive/CMakeLists.txt index 171b951f2524..8562dc5dbda5 100644 --- a/velox/experimental/cudf/connectors/parquet/CMakeLists.txt +++ b/velox/experimental/cudf/connectors/hive/CMakeLists.txt @@ -12,26 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_cudf_parquet_config ParquetConfig.cpp) - -set_target_properties(velox_cudf_parquet_config PROPERTIES CUDA_ARCHITECTURES native) - -target_link_libraries(velox_cudf_parquet_config velox_core velox_exception cudf::cudf) - add_library( - velox_cudf_parquet_connector - OBJECT - ParquetConfig.cpp - ParquetConnector.cpp - ParquetConnectorSplit.cpp - ParquetDataSource.cpp - ParquetDataSink.cpp - ParquetTableHandle.cpp + velox_cudf_hive_connector + CudfHiveConfig.cpp + CudfHiveConnector.cpp + CudfHiveConnectorSplit.cpp + CudfHiveDataSource.cpp + CudfHiveDataSink.cpp + CudfHiveTableHandle.cpp ) -set_target_properties(velox_cudf_parquet_connector PROPERTIES CUDA_ARCHITECTURES native) +set_target_properties(velox_cudf_hive_connector PROPERTIES CUDA_ARCHITECTURES native) -target_link_libraries( - velox_cudf_parquet_connector - PRIVATE cudf::cudf velox_common_io velox_connector velox_type_tz velox_gcs -) +target_link_libraries(velox_cudf_hive_connector PRIVATE velox_hive_connector PUBLIC cudf::cudf) diff --git a/velox/experimental/cudf/connectors/parquet/ParquetConfig.cpp b/velox/experimental/cudf/connectors/hive/CudfHiveConfig.cpp similarity index 73% rename from velox/experimental/cudf/connectors/parquet/ParquetConfig.cpp rename to velox/experimental/cudf/connectors/hive/CudfHiveConfig.cpp index 154498379303..13ce62286286 100644 --- 
a/velox/experimental/cudf/connectors/parquet/ParquetConfig.cpp +++ b/velox/experimental/cudf/connectors/hive/CudfHiveConfig.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "velox/experimental/cudf/connectors/parquet/ParquetConfig.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConfig.h" #include "velox/common/base/Exceptions.h" #include "velox/common/config/Config.h" @@ -23,25 +23,25 @@ #include -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { -int64_t ParquetConfig::skipRows() const { +int64_t CudfHiveConfig::skipRows() const { return config_->get(kSkipRows, 0); } -std::optional ParquetConfig::numRows() const { +std::optional CudfHiveConfig::numRows() const { auto numRows = config_->get(kNumRows); return numRows.has_value() ? std::make_optional(numRows.value()) : std::nullopt; } -std::size_t ParquetConfig::maxChunkReadLimit() const { +std::size_t CudfHiveConfig::maxChunkReadLimit() const { // chunk read limit = 0 means no limit return config_->get(kMaxChunkReadLimit, 0); } -std::size_t ParquetConfig::maxChunkReadLimitSession( +std::size_t CudfHiveConfig::maxChunkReadLimitSession( const config::ConfigBase* session) const { // pass read limit = 0 means no limit return session->get( @@ -49,12 +49,12 @@ std::size_t ParquetConfig::maxChunkReadLimitSession( config_->get(kMaxChunkReadLimit, 0)); } -std::size_t ParquetConfig::maxPassReadLimit() const { +std::size_t CudfHiveConfig::maxPassReadLimit() const { // pass read limit = 0 means no limit return config_->get(kMaxPassReadLimit, 0); } -std::size_t ParquetConfig::maxPassReadLimitSession( +std::size_t CudfHiveConfig::maxPassReadLimitSession( const config::ConfigBase* session) const { // pass read limit = 0 means no limit return session->get( @@ -62,49 +62,49 @@ std::size_t ParquetConfig::maxPassReadLimitSession( config_->get(kMaxPassReadLimit, 0)); } -bool ParquetConfig::isConvertStringsToCategories() const { +bool 
CudfHiveConfig::isConvertStringsToCategories() const { return config_->get(kConvertStringsToCategories, false); } -bool ParquetConfig::isConvertStringsToCategoriesSession( +bool CudfHiveConfig::isConvertStringsToCategoriesSession( const config::ConfigBase* session) const { return session->get( kConvertStringsToCategoriesSession, config_->get(kConvertStringsToCategories, false)); } -bool ParquetConfig::isUsePandasMetadata() const { +bool CudfHiveConfig::isUsePandasMetadata() const { return config_->get(kUsePandasMetadata, true); } -bool ParquetConfig::isUsePandasMetadataSession( +bool CudfHiveConfig::isUsePandasMetadataSession( const config::ConfigBase* session) const { return session->get( kUsePandasMetadataSession, config_->get(kUsePandasMetadata, true)); } -bool ParquetConfig::isUseArrowSchema() const { +bool CudfHiveConfig::isUseArrowSchema() const { return config_->get(kUseArrowSchema, true); } -bool ParquetConfig::isUseArrowSchemaSession( +bool CudfHiveConfig::isUseArrowSchemaSession( const config::ConfigBase* session) const { return session->get( kUseArrowSchemaSession, config_->get(kUseArrowSchema, true)); } -bool ParquetConfig::isAllowMismatchedParquetSchemas() const { - return config_->get(kAllowMismatchedParquetSchemas, false); +bool CudfHiveConfig::isAllowMismatchedCudfHiveSchemas() const { + return config_->get(kAllowMismatchedCudfHiveSchemas, false); } -bool ParquetConfig::isAllowMismatchedParquetSchemasSession( +bool CudfHiveConfig::isAllowMismatchedCudfHiveSchemasSession( const config::ConfigBase* session) const { return session->get( - kAllowMismatchedParquetSchemasSession, - config_->get(kAllowMismatchedParquetSchemas, false)); + kAllowMismatchedCudfHiveSchemasSession, + config_->get(kAllowMismatchedCudfHiveSchemas, false)); } -cudf::data_type ParquetConfig::timestampType() const { +cudf::data_type CudfHiveConfig::timestampType() const { const auto unit = config_->get( kTimestampType, cudf::type_id::TIMESTAMP_MILLISECONDS /*milli*/); VELOX_CHECK( 
@@ -117,7 +117,7 @@ cudf::data_type ParquetConfig::timestampType() const { return cudf::data_type(cudf::type_id{unit}); } -cudf::data_type ParquetConfig::timestampTypeSession( +cudf::data_type CudfHiveConfig::timestampTypeSession( const config::ConfigBase* session) const { const auto unit = session->get( kTimestampTypeSession, @@ -133,47 +133,47 @@ cudf::data_type ParquetConfig::timestampTypeSession( return cudf::data_type(cudf::type_id{unit}); } -bool ParquetConfig::immutableFiles() const { +bool CudfHiveConfig::immutableFiles() const { return config_->get(kImmutableFiles, false); } -uint64_t ParquetConfig::sortWriterFinishTimeSliceLimitMs( +uint64_t CudfHiveConfig::sortWriterFinishTimeSliceLimitMs( const config::ConfigBase* session) const { return session->get( kSortWriterFinishTimeSliceLimitMsSession, config_->get(kSortWriterFinishTimeSliceLimitMs, 5'000)); } -bool ParquetConfig::writeTimestampsAsUTC() const { +bool CudfHiveConfig::writeTimestampsAsUTC() const { return config_->get(kWriteTimestampsAsUTC, true); } -bool ParquetConfig::writeTimestampsAsUTCSession( +bool CudfHiveConfig::writeTimestampsAsUTCSession( const config::ConfigBase* session) const { return session->get( kWriteTimestampsAsUTCSession, config_->get(kWriteTimestampsAsUTC, true)); } -bool ParquetConfig::writeArrowSchema() const { +bool CudfHiveConfig::writeArrowSchema() const { return config_->get(kWriteArrowSchema, false); } -bool ParquetConfig::writeArrowSchemaSession( +bool CudfHiveConfig::writeArrowSchemaSession( const config::ConfigBase* session) const { return session->get( kWriteArrowSchemaSession, config_->get(kWriteArrowSchema, false)); } -bool ParquetConfig::writev2PageHeaders() const { +bool CudfHiveConfig::writev2PageHeaders() const { return config_->get(kWritev2PageHeaders, false); } -bool ParquetConfig::writev2PageHeadersSession( +bool CudfHiveConfig::writev2PageHeadersSession( const config::ConfigBase* session) const { return session->get( kWritev2PageHeadersSession, 
config_->get(kWritev2PageHeaders, false)); } -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/parquet/ParquetConfig.h b/velox/experimental/cudf/connectors/hive/CudfHiveConfig.h similarity index 88% rename from velox/experimental/cudf/connectors/parquet/ParquetConfig.h rename to velox/experimental/cudf/connectors/hive/CudfHiveConfig.h index bf80ad8b0ec6..d2977d93f719 100644 --- a/velox/experimental/cudf/connectors/parquet/ParquetConfig.h +++ b/velox/experimental/cudf/connectors/hive/CudfHiveConfig.h @@ -26,14 +26,14 @@ namespace facebook::velox::config { class ConfigBase; } -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { -class ParquetConfig { +class CudfHiveConfig { public: // Reader config options - // Number of rows to skip from the start; Parquet stores the number of rows as - // int64_t + // Number of rows to skip from the start; CudfHive stores the number of rows + // as int64_t static constexpr const char* kSkipRows = "parquet.reader.skip-rows"; // Number of rows to read; `nullopt` is all @@ -69,11 +69,11 @@ class ParquetConfig { static constexpr const char* kUseArrowSchemaSession = "parquet.reader.use_arrow_schema"; - // Whether to allow reading matching select columns from mismatched Parquet + // Whether to allow reading matching select columns from mismatched CudfHive // files. 
- static constexpr const char* kAllowMismatchedParquetSchemas = + static constexpr const char* kAllowMismatchedCudfHiveSchemas = "parquet.reader.allow-mismatched-parquet-schemas"; - static constexpr const char* kAllowMismatchedParquetSchemasSession = + static constexpr const char* kAllowMismatchedCudfHiveSchemasSession = "parquet.reader.allow_mismatched_parquet_schemas"; // Cast timestamp columns to a specific type @@ -83,7 +83,7 @@ class ParquetConfig { // Writer config options - /// Whether new data can be inserted into a Parquet file + /// Whether new data can be inserted into a CudfHive file /// Cudf-Velox currently does not support appending data to existing files. static constexpr const char* kImmutableFiles = "parquet.immutable-files"; @@ -109,9 +109,9 @@ class ParquetConfig { static constexpr const char* kWritev2PageHeadersSession = "parquet.writer.write_v2_page_headers"; - ParquetConfig(std::shared_ptr config) { + CudfHiveConfig(std::shared_ptr config) { VELOX_CHECK_NOT_NULL( - config, "Config is null for ParquetConfig initialization"); + config, "Config is null for CudfHiveConfig initialization"); config_ = std::move(config); } @@ -141,8 +141,8 @@ class ParquetConfig { bool isUseArrowSchema() const; bool isUseArrowSchemaSession(const config::ConfigBase* session) const; - bool isAllowMismatchedParquetSchemas() const; - bool isAllowMismatchedParquetSchemasSession( + bool isAllowMismatchedCudfHiveSchemas() const; + bool isAllowMismatchedCudfHiveSchemasSession( const config::ConfigBase* session) const; cudf::data_type timestampType() const; @@ -162,4 +162,4 @@ class ParquetConfig { private: std::shared_ptr config_; }; -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/hive/CudfHiveConnector.cpp b/velox/experimental/cudf/connectors/hive/CudfHiveConnector.cpp new file mode 100644 index 000000000000..4806c196779f --- /dev/null +++ 
b/velox/experimental/cudf/connectors/hive/CudfHiveConnector.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnector.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveDataSource.h" +#include "velox/experimental/cudf/exec/ToCudf.h" + +#include "velox/connectors/hive/HiveDataSource.h" + +namespace facebook::velox::cudf_velox::connector::hive { + +using namespace facebook::velox::connector; + +CudfHiveConnector::CudfHiveConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* executor) + : ::facebook::velox::connector::hive::HiveConnector(id, config, executor), + cudfHiveConfig_(std::make_shared(config)) { + LOG(INFO) << "cuDF Hive connector created"; +} + +std::unique_ptr CudfHiveConnector::createDataSource( + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const ColumnHandleMap& columnHandles, + ConnectorQueryCtx* connectorQueryCtx) { + // If it's parquet then return CudfHiveDataSource + // If it's not parquet then return HiveDataSource + // TODO (dm): Make this ^^^ happen + // Problem: this information is in split, not table handle + + if (cudfIsRegistered()) { + return std::make_unique( + outputType, + tableHandle, + columnHandles, + ioExecutor_, + connectorQueryCtx, + cudfHiveConfig_); + } + + return 
std::make_unique<::facebook::velox::connector::hive::HiveDataSource>( + outputType, + tableHandle, + columnHandles, + &fileHandleFactory_, + ioExecutor_, + connectorQueryCtx, + hiveConfig_); +} + +// TODO (dm): Re-add data sink + +std::shared_ptr CudfHiveConnectorFactory::newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor, + folly::Executor* cpuExecutor) { + return std::make_shared(id, config, ioExecutor); +} + +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/parquet/ParquetConnector.h b/velox/experimental/cudf/connectors/hive/CudfHiveConnector.h similarity index 54% rename from velox/experimental/cudf/connectors/parquet/ParquetConnector.h rename to velox/experimental/cudf/connectors/hive/CudfHiveConnector.h index fc5b384c374b..fbb5929f3f42 100644 --- a/velox/experimental/cudf/connectors/parquet/ParquetConnector.h +++ b/velox/experimental/cudf/connectors/hive/CudfHiveConnector.h @@ -16,25 +16,23 @@ #pragma once -#include "velox/experimental/cudf/connectors/parquet/ParquetConfig.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetDataSink.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetDataSource.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConfig.h" -#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/HiveConnector.h" #include #include #include -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { using namespace facebook::velox::connector; using namespace facebook::velox::config; -class ParquetConnector final : public Connector { +class CudfHiveConnector final + : public ::facebook::velox::connector::hive::HiveConnector { public: - ParquetConnector( + CudfHiveConnector( const std::string& id, std::shared_ptr config, folly::Executor* executor); @@ -45,29 
+43,26 @@ class ParquetConnector final : public Connector { const ColumnHandleMap& columnHandles, ConnectorQueryCtx* connectorQueryCtx) override final; - std::unique_ptr createDataSink( - RowTypePtr inputType, - ConnectorInsertTableHandlePtr connectorInsertTableHandle, - ConnectorQueryCtx* connectorQueryCtx, - CommitStrategy commitStrategy) override final; - - folly::Executor* executor() const override { - return executor_; + bool canAddDynamicFilter() const override { + return false; } + // TODO (dm): Re-add data sink + protected: - const std::shared_ptr parquetConfig_; - folly::Executor* executor_; + // TODO (dm): rename parquetconfig + const std::shared_ptr cudfHiveConfig_; }; -class ParquetConnectorFactory : public ConnectorFactory { +class CudfHiveConnectorFactory + : public ::facebook::velox::connector::hive::HiveConnectorFactory { public: - static constexpr const char* kParquetConnectorName = "parquet"; - - ParquetConnectorFactory() : ConnectorFactory(kParquetConnectorName) {} + CudfHiveConnectorFactory() + : ::facebook::velox::connector::hive::HiveConnectorFactory() {} - explicit ParquetConnectorFactory(const char* connectorName) - : ConnectorFactory(connectorName) {} + explicit CudfHiveConnectorFactory(const char* connectorName) + : ::facebook::velox::connector::hive::HiveConnectorFactory( + connectorName) {} std::shared_ptr newConnector( const std::string& id, @@ -76,4 +71,4 @@ class ParquetConnectorFactory : public ConnectorFactory { folly::Executor* cpuExecutor = nullptr) override; }; -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.cpp b/velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.cpp new file mode 100644 index 000000000000..c4f04c6be0aa --- /dev/null +++ b/velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) Facebook, Inc. 
and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.h" + +#include + +#include + +namespace facebook::velox::cudf_velox::connector::hive { + +namespace { +std::string stripFilePrefix(const std::string& targetPath) { + const std::string prefix = "file:"; + if (targetPath.rfind(prefix, 0) == 0) { + return targetPath.substr(prefix.length()); + } + return targetPath; +} +} // namespace + +std::string CudfHiveConnectorSplit::toString() const { + return fmt::format("CudfHive: {}", filePath); +} + +std::string CudfHiveConnectorSplit::getFileName() const { + const auto i = filePath.rfind('/'); + return i == std::string::npos ? 
filePath : filePath.substr(i + 1); +} + +CudfHiveConnectorSplit::CudfHiveConnectorSplit( + const std::string& connectorId, + const std::string& _filePath, + int64_t _splitWeight) + : facebook::velox::connector::ConnectorSplit(connectorId, _splitWeight), + filePath(stripFilePrefix(_filePath)), + cudfSourceInfo(std::make_unique(filePath)) {} + +// static +std::shared_ptr CudfHiveConnectorSplit::create( + const folly::dynamic& obj) { + const auto connectorId = obj["connectorId"].asString(); + const auto splitWeight = obj["splitWeight"].asInt(); + const auto filePath = obj["filePath"].asString(); + + return std::make_shared( + connectorId, filePath, splitWeight); +} + +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.h b/velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.h similarity index 63% rename from velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.h rename to velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.h index 72e9ba7a572e..6ae350a474ab 100644 --- a/velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.h +++ b/velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.h @@ -19,55 +19,57 @@ #include "velox/connectors/Connector.h" #include "velox/dwio/common/Options.h" -#include +namespace cudf { +namespace io { +struct source_info; +} +} // namespace cudf +#include #include -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { -struct ParquetConnectorSplit +struct CudfHiveConnectorSplit : public facebook::velox::connector::ConnectorSplit { const std::string filePath; const facebook::velox::dwio::common::FileFormat fileFormat{ facebook::velox::dwio::common::FileFormat::PARQUET}; - const cudf::io::source_info cudfSourceInfo; + const std::unique_ptr cudfSourceInfo; - ParquetConnectorSplit( + CudfHiveConnectorSplit( const std::string& connectorId, const 
std::string& _filePath, - int64_t _splitWeight = 0) - : facebook::velox::connector::ConnectorSplit(connectorId, _splitWeight), - filePath(_filePath), - cudfSourceInfo({filePath}) {} + int64_t _splitWeight = 0); std::string toString() const override; std::string getFileName() const; const cudf::io::source_info& getCudfSourceInfo() const { - return cudfSourceInfo; + return *cudfSourceInfo; } - static std::shared_ptr create( + static std::shared_ptr create( const folly::dynamic& obj); }; -class ParquetConnectorSplitBuilder { +class CudfHiveConnectorSplitBuilder { public: - explicit ParquetConnectorSplitBuilder(std::string filePath) + explicit CudfHiveConnectorSplitBuilder(std::string filePath) : filePath_{std::move(filePath)} {} - ParquetConnectorSplitBuilder& splitWeight(int64_t splitWeight) { + CudfHiveConnectorSplitBuilder& splitWeight(int64_t splitWeight) { splitWeight_ = splitWeight; return *this; } - ParquetConnectorSplitBuilder& connectorId(const std::string& connectorId) { + CudfHiveConnectorSplitBuilder& connectorId(const std::string& connectorId) { connectorId_ = connectorId; return *this; } - std::shared_ptr build() const { - return std::make_shared( + std::shared_ptr build() const { + return std::make_shared( connectorId_, filePath_, splitWeight_); } @@ -77,4 +79,4 @@ class ParquetConnectorSplitBuilder { int64_t splitWeight_{0}; }; -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/parquet/ParquetDataSink.cpp b/velox/experimental/cudf/connectors/hive/CudfHiveDataSink.cpp similarity index 84% rename from velox/experimental/cudf/connectors/parquet/ParquetDataSink.cpp rename to velox/experimental/cudf/connectors/hive/CudfHiveDataSink.cpp index 98f66adde6fa..7c6d06163c4a 100644 --- a/velox/experimental/cudf/connectors/parquet/ParquetDataSink.cpp +++ b/velox/experimental/cudf/connectors/hive/CudfHiveDataSink.cpp @@ -14,9 +14,9 @@ * 
limitations under the License. */ -#include "velox/experimental/cudf/connectors/parquet/ParquetConfig.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetDataSink.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConfig.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveDataSink.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.h" #include "velox/experimental/cudf/exec/Utilities.h" #include "velox/experimental/cudf/exec/VeloxCudfInterop.h" #include "velox/experimental/cudf/vector/CudfVector.h" @@ -39,7 +39,7 @@ using facebook::velox::common::testutil::TestValue; -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { namespace { @@ -58,8 +58,8 @@ std::unordered_map invertMap(const std::unordered_map& mapping) { return inverted; } -uint64_t getFinishTimeSliceLimitMsFromParquetConfig( - const std::shared_ptr& config, +uint64_t getFinishTimeSliceLimitMsFromCudfHiveConfig( + const std::shared_ptr& config, const config::ConfigBase* sessions) { const uint64_t flushTimeSliceLimitMsFromConfig = config->sortWriterFinishTimeSliceLimitMs(sessions); @@ -122,12 +122,12 @@ LocationHandle::TableType LocationHandle::tableTypeFromName( return kNameTableTypes.at(name); } -ParquetDataSink::ParquetDataSink( +CudfHiveDataSink::CudfHiveDataSink( RowTypePtr inputType, - std::shared_ptr insertTableHandle, + std::shared_ptr insertTableHandle, const ConnectorQueryCtx* connectorQueryCtx, CommitStrategy commitStrategy, - const std::shared_ptr& parquetConfig) + const std::shared_ptr& parquetConfig) : inputType_(std::move(inputType)), insertTableHandle_(std::move(insertTableHandle)), connectorQueryCtx_(connectorQueryCtx), @@ -135,16 +135,16 @@ ParquetDataSink::ParquetDataSink( parquetConfig_(parquetConfig), spillConfig_(connectorQueryCtx->spillConfig()), sortWriterFinishTimeSliceLimitMs_( - 
getFinishTimeSliceLimitMsFromParquetConfig( + getFinishTimeSliceLimitMsFromCudfHiveConfig( parquetConfig_, connectorQueryCtx->sessionProperties())) { VELOX_USER_CHECK( (commitStrategy_ == CommitStrategy::kNoCommit) || (commitStrategy_ == CommitStrategy::kTaskCommit), "Unsupported commit strategy: {}", - commitStrategyToString(commitStrategy_)); + CommitStrategyName::toName(commitStrategy_)); - const auto& writerOptions = dynamic_cast( + const auto& writerOptions = dynamic_cast( insertTableHandle_->writerOptions().get()); if (writerOptions != nullptr) { @@ -152,7 +152,7 @@ ParquetDataSink::ParquetDataSink( } } -void ParquetDataSink::appendData(RowVectorPtr input) { +void CudfHiveDataSink::appendData(RowVectorPtr input) { checkRunning(); // Convert the input RowVectorPtr to cudf::table @@ -174,7 +174,7 @@ void ParquetDataSink::appendData(RowVectorPtr input) { } std::unique_ptr -ParquetDataSink::createCudfWriter(cudf::table_view cudfTable) { +CudfHiveDataSink::createCudfWriter(cudf::table_view cudfTable) { // Create a table_input_metadata from the input auto tableInputMetadata = createCudfTableInputMetadata(cudfTable); @@ -188,8 +188,8 @@ ParquetDataSink::createCudfWriter(cudf::table_view cudfTable) { ? 
fmt::format("{}{}", makeUuid(), ".parquet") : locationHandle->targetFileName(); - auto writerParameters = ParquetWriterParameters( - ParquetWriterParameters::UpdateMode::kNew, + auto writerParameters = CudfHiveWriterParameters( + CudfHiveWriterParameters::UpdateMode::kNew, targetFileName, locationHandle->targetPath()); @@ -210,7 +210,7 @@ ParquetDataSink::createCudfWriter(cudf::table_view cudfTable) { .compression(compressionKind) .build(); - const auto& writerOptions = dynamic_cast( + const auto& writerOptions = dynamic_cast( insertTableHandle_->writerOptions().get()); // If non-null writerOptions were passed, pass them to the chunked parquet @@ -260,29 +260,30 @@ ParquetDataSink::createCudfWriter(cudf::table_view cudfTable) { return std::make_unique(cudfWriterOptions); } -cudf::io::table_input_metadata ParquetDataSink::createCudfTableInputMetadata( +cudf::io::table_input_metadata CudfHiveDataSink::createCudfTableInputMetadata( cudf::table_view cudfTable) { auto tableInputMetadata = cudf::io::table_input_metadata(cudfTable); auto inputColumns = insertTableHandle_->inputColumns(); // Check if equal number of columns in the input and - // ParquetInsertTableHandle + // CudfHiveInsertTableHandle VELOX_CHECK_EQ( tableInputMetadata.column_metadata.size(), inputColumns.size(), - "Unequal number of columns in the input and ParquetInsertTableHandle"); + "Unequal number of columns in the input and CudfHiveInsertTableHandle"); - std::function + std::function setColumnName = [&](cudf::io::column_in_metadata& colMeta, - const ParquetColumnHandle& columnHandle) { + const CudfHiveColumnHandle& columnHandle) { // Check if equal number of children const auto& childrenHandles = columnHandle.children(); - // Warn if the mismatch in the number of child cols in Parquet + // Warn if the mismatch in the number of child cols in CudfHive // table_metadata and columnHandles if (colMeta.num_children() != childrenHandles.size()) { LOG(WARNING) << fmt::format( - "({} vs {}): Unequal number of 
child columns in Parquet table_metadata and ColumnHandles", + "({} vs {}): Unequal number of child columns in CudfHive table_metadata and ColumnHandles", colMeta.num_children(), childrenHandles.size()); } @@ -305,7 +306,7 @@ cudf::io::table_input_metadata ParquetDataSink::createCudfTableInputMetadata( return tableInputMetadata; } -std::string ParquetDataSink::stateString(State state) { +std::string CudfHiveDataSink::stateString(State state) { switch (state) { case State::kRunning: return "RUNNING"; @@ -320,7 +321,7 @@ std::string ParquetDataSink::stateString(State state) { } } -DataSink::Stats ParquetDataSink::stats() const { +DataSink::Stats CudfHiveDataSink::stats() const { Stats stats; if (state_ == State::kAborted) { return stats; @@ -349,13 +350,13 @@ DataSink::Stats ParquetDataSink::stats() const { return stats; } -void ParquetDataSink::setState(State newState) { +void CudfHiveDataSink::setState(State newState) { checkStateTransition(state_, newState); state_ = newState; } /// Validates the state transition from 'oldState' to 'newState'. 
-void ParquetDataSink::checkStateTransition(State oldState, State newState) { +void CudfHiveDataSink::checkStateTransition(State oldState, State newState) { switch (oldState) { case State::kRunning: if (newState == State::kAborted || newState == State::kFinishing) { @@ -378,14 +379,14 @@ void ParquetDataSink::checkStateTransition(State oldState, State newState) { VELOX_FAIL("Unexpected state transition from {} to {}", oldState, newState); } -bool ParquetDataSink::finish() { - VELOX_CHECK_NOT_NULL(writer_, "ParquetDataSink has no writer"); +bool CudfHiveDataSink::finish() { + VELOX_CHECK_NOT_NULL(writer_, "CudfHiveDataSink has no writer"); setState(State::kFinishing); return true; } -std::vector ParquetDataSink::close() { +std::vector CudfHiveDataSink::close() { setState(State::kClosed); closeInternal(); @@ -413,18 +414,18 @@ std::vector ParquetDataSink::close() { return partitionUpdates; } -void ParquetDataSink::abort() { +void CudfHiveDataSink::abort() { setState(State::kAborted); closeInternal(); } -void ParquetDataSink::closeInternal() { +void CudfHiveDataSink::closeInternal() { VELOX_CHECK_NE(state_, State::kRunning); VELOX_CHECK_NE(state_, State::kFinishing); - VELOX_CHECK_NOT_NULL(writer_, "ParquetDataSink has no writer"); + VELOX_CHECK_NOT_NULL(writer_, "CudfHiveDataSink has no writer"); TestValue::adjust( - "facebook::velox::connector::parquet::ParquetDataSink::closeInternal", + "facebook::velox::connector::hive::CudfHiveDataSink::closeInternal", this); // Close cudf writer @@ -434,14 +435,14 @@ void ParquetDataSink::closeInternal() { writer_.reset(); } -std::shared_ptr ParquetDataSink::createWriterPool() { +std::shared_ptr CudfHiveDataSink::createWriterPool() { auto* connectorPool = connectorQueryCtx_->connectorMemoryPool(); return connectorPool->addAggregateChild( fmt::format("{}.{}", connectorPool->name(), "parquet-writer")); } -void ParquetDataSink::makeWriterOptions( - ParquetWriterParameters writerParameters) { +void 
CudfHiveDataSink::makeWriterOptions( + CudfHiveWriterParameters writerParameters) { auto writerPool = createWriterPool(); auto sinkPool = createSinkPool(writerPool); std::shared_ptr sortPool{nullptr}; @@ -449,7 +450,7 @@ void ParquetDataSink::makeWriterOptions( sortPool = createSortPool(writerPool); } - writerInfo_ = std::make_shared( + writerInfo_ = std::make_shared( std::move(writerParameters), std::move(writerPool), std::move(sinkPool), @@ -461,7 +462,7 @@ void ParquetDataSink::makeWriterOptions( // or allocate a new one. auto options = insertTableHandle_->writerOptions(); if (!options) { - options = std::make_unique(); + options = std::make_unique(); } const auto* connectorSessionProperties = @@ -483,9 +484,9 @@ void ParquetDataSink::makeWriterOptions( connectorQueryCtx_->adjustTimestampToTimezone(); } -folly::dynamic ParquetInsertTableHandle::serialize() const { +folly::dynamic CudfHiveInsertTableHandle::serialize() const { folly::dynamic obj = folly::dynamic::object; - obj["name"] = "ParquetInsertTableHandle"; + obj["name"] = "CudfHiveInsertTableHandle"; folly::dynamic arr = folly::dynamic::array; for (const auto& ic : inputColumns_) { arr.push_back(ic->serialize()); @@ -502,10 +503,10 @@ folly::dynamic ParquetInsertTableHandle::serialize() const { return obj; } -ParquetInsertTableHandlePtr ParquetInsertTableHandle::create( +CudfHiveInsertTableHandlePtr CudfHiveInsertTableHandle::create( const folly::dynamic& obj) { auto inputColumns = - ISerializable::deserialize>( + ISerializable::deserialize>( obj["inputColumns"]); auto locationHandle = ISerializable::deserialize(obj["locationHandle"]); @@ -518,13 +519,14 @@ ParquetInsertTableHandlePtr ParquetInsertTableHandle::create( for (const auto& pair : obj["serdeParameters"].items()) { serdeParameters.emplace(pair.first.asString(), pair.second.asString()); } - return std::make_shared( + return std::make_shared( inputColumns, locationHandle, compressionKind, serdeParameters); } -std::string 
ParquetInsertTableHandle::toString() const { +std::string CudfHiveInsertTableHandle::toString() const { std::ostringstream out; - out << "ParquetInsertTableHandle [" << dwio::common::toString(storageFormat_); + out << "CudfHiveInsertTableHandle [" + << dwio::common::toString(storageFormat_); if (compressionKind_.has_value()) { out << " " << common::compressionKindToString(compressionKind_.value()); } else { @@ -540,9 +542,9 @@ std::string ParquetInsertTableHandle::toString() const { return out.str(); } -void ParquetInsertTableHandle::registerSerDe() { +void CudfHiveInsertTableHandle::registerSerDe() { auto& registry = DeserializationRegistryForSharedPtr(); - registry.Register("HiveInsertTableHandle", ParquetInsertTableHandle::create); + registry.Register("HiveInsertTableHandle", CudfHiveInsertTableHandle::create); } std::string LocationHandle::toString() const { @@ -566,4 +568,4 @@ LocationHandlePtr LocationHandle::create(const folly::dynamic& obj) { return std::make_shared(targetPath, tableType); } -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/parquet/ParquetDataSink.h b/velox/experimental/cudf/connectors/hive/CudfHiveDataSink.h similarity index 81% rename from velox/experimental/cudf/connectors/parquet/ParquetDataSink.h rename to velox/experimental/cudf/connectors/hive/CudfHiveDataSink.h index 872ee9f47fac..02984124c247 100644 --- a/velox/experimental/cudf/connectors/parquet/ParquetDataSink.h +++ b/velox/experimental/cudf/connectors/hive/CudfHiveDataSink.h @@ -15,9 +15,9 @@ */ #pragma once -#include "velox/experimental/cudf/connectors/parquet/ParquetConfig.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h" -#include "velox/experimental/cudf/connectors/parquet/WriterOptions.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConfig.h" +#include 
"velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.h" +#include "velox/experimental/cudf/connectors/hive/WriterOptions.h" #include "velox/common/compression/Compression.h" #include "velox/connectors/Connector.h" @@ -31,14 +31,14 @@ #include #include -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { using namespace facebook::velox::connector; class LocationHandle; using LocationHandlePtr = std::shared_ptr; -/// Location related properties of the Parquet table to be written. +/// Location related properties of the CudfHive table to be written. class LocationHandle : public ISerializable { public: enum class TableType { @@ -88,7 +88,7 @@ class LocationHandle : public ISerializable { }; /// Parameters for Hive writers. -class ParquetWriterParameters { +class CudfHiveWriterParameters { public: enum class UpdateMode { kNew, // Write files to a new directory. @@ -105,7 +105,7 @@ class ParquetWriterParameters { /// @param writeDirectory The temporary directory that a running writer writes /// to. If a running writer writes directory to the target directory, set /// writeDirectory to targetDirectory by default. 
- ParquetWriterParameters( + CudfHiveWriterParameters( UpdateMode updateMode, std::string targetFileName, std::string targetDirectory, @@ -155,9 +155,9 @@ class ParquetWriterParameters { const std::string writeDirectory_; }; -struct ParquetWriterInfo { - ParquetWriterInfo( - ParquetWriterParameters parameters, +struct CudfHiveWriterInfo { + CudfHiveWriterInfo( + CudfHiveWriterParameters parameters, std::shared_ptr _writerPool, std::shared_ptr _sinkPool, std::shared_ptr _sortPool) @@ -168,7 +168,7 @@ struct ParquetWriterInfo { sinkPool(std::move(_sinkPool)), sortPool(std::move(_sortPool)) {} - const ParquetWriterParameters writerParameters; + const CudfHiveWriterParameters writerParameters; const std::unique_ptr> nonReclaimableSectionHolder; /// Collects the spill stats from sort writer if the spilling has been /// triggered. @@ -180,14 +180,14 @@ struct ParquetWriterInfo { int64_t inputSizeInBytes = 0; }; -class ParquetInsertTableHandle; -using ParquetInsertTableHandlePtr = std::shared_ptr; +class CudfHiveInsertTableHandle; +using CudfHiveInsertTableHandlePtr = std::shared_ptr; -/// Represents a request for Parquet write. -class ParquetInsertTableHandle : public ConnectorInsertTableHandle { +/// Represents a request for CudfHive write. 
+class CudfHiveInsertTableHandle : public ConnectorInsertTableHandle { public: - ParquetInsertTableHandle( - std::vector> inputColumns, + CudfHiveInsertTableHandle( + std::vector> inputColumns, std::shared_ptr locationHandle, std::optional compressionKind = {}, const std::unordered_map& serdeParameters = {}, @@ -207,13 +207,13 @@ class ParquetInsertTableHandle : public ConnectorInsertTableHandle { compressionKind.value() == common::CompressionKind_SNAPPY or compressionKind.value() == common::CompressionKind_LZ4 or compressionKind.value() == common::CompressionKind_ZSTD, - "Parquet DataSink only supports NONE, SNAPPY, LZ4, and ZSTD compressions."); + "CudfHive DataSink only supports NONE, SNAPPY, LZ4, and ZSTD compressions."); } } - virtual ~ParquetInsertTableHandle() = default; + virtual ~CudfHiveInsertTableHandle() = default; - const std::vector>& inputColumns() + const std::vector>& inputColumns() const { return inputColumns_; } @@ -243,20 +243,20 @@ class ParquetInsertTableHandle : public ConnectorInsertTableHandle { } bool isExistingTable() const { - return false; // This is always false as cudf's Parquet writer doesn't yet - // support updating existing Parquet files + return false; // This is always false as cudf's CudfHive writer doesn't yet + // support updating existing CudfHive files } folly::dynamic serialize() const override; - static ParquetInsertTableHandlePtr create(const folly::dynamic& obj); + static CudfHiveInsertTableHandlePtr create(const folly::dynamic& obj); static void registerSerDe(); std::string toString() const override; private: - const std::vector> inputColumns_; + const std::vector> inputColumns_; const std::shared_ptr locationHandle_; const std::optional compressionKind_; const dwio::common::FileFormat storageFormat_ = @@ -265,7 +265,7 @@ class ParquetInsertTableHandle : public ConnectorInsertTableHandle { const std::shared_ptr writerOptions_; }; -class ParquetDataSink : public DataSink { +class CudfHiveDataSink : public DataSink { 
public: /// The list of runtime stats reported by parquet data sink static constexpr const char* kEarlyFlushedRawBytes = "earlyFlushedRawBytes"; @@ -284,12 +284,12 @@ class ParquetDataSink : public DataSink { }; static std::string stateString(State state); - ParquetDataSink( + CudfHiveDataSink( RowTypePtr inputType, - std::shared_ptr insertTableHandle, + std::shared_ptr insertTableHandle, const ConnectorQueryCtx* connectorQueryCtx, CommitStrategy commitStrategy, - const std::shared_ptr& parquetConfig); + const std::shared_ptr& parquetConfig); void appendData(RowVectorPtr input) override; @@ -327,17 +327,18 @@ class ParquetDataSink : public DataSink { } FOLLY_ALWAYS_INLINE void checkRunning() const { - VELOX_CHECK_EQ(state_, State::kRunning, "Parquet data sink is not running"); + VELOX_CHECK_EQ( + state_, State::kRunning, "CudfHive data sink is not running"); } void closeInternal(); - void makeWriterOptions(ParquetWriterParameters writerParameters); + void makeWriterOptions(CudfHiveWriterParameters writerParameters); const RowTypePtr inputType_; - const std::shared_ptr insertTableHandle_; + const std::shared_ptr insertTableHandle_; const ConnectorQueryCtx* const connectorQueryCtx_; const CommitStrategy commitStrategy_; - const std::shared_ptr parquetConfig_; + const std::shared_ptr parquetConfig_; const common::SpillConfig* const spillConfig_; const uint64_t sortWriterFinishTimeSliceLimitMs_{0}; State state_{State::kRunning}; @@ -348,7 +349,7 @@ class ParquetDataSink : public DataSink { std::vector sortingColumns_; - std::shared_ptr writerInfo_; + std::shared_ptr writerInfo_; // IO statistics collected for writer. 
std::shared_ptr ioStats_; @@ -356,21 +357,21 @@ class ParquetDataSink : public DataSink { FOLLY_ALWAYS_INLINE std::ostream& operator<<( std::ostream& os, - ParquetDataSink::State state) { - os << ParquetDataSink::stateString(state); + CudfHiveDataSink::State state) { + os << CudfHiveDataSink::stateString(state); return os; } -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive template <> struct fmt::formatter< - facebook::velox::cudf_velox::connector::parquet::ParquetDataSink::State> + facebook::velox::cudf_velox::connector::hive::CudfHiveDataSink::State> : formatter { auto format( - facebook::velox::cudf_velox::connector::parquet::ParquetDataSink::State s, + facebook::velox::cudf_velox::connector::hive::CudfHiveDataSink::State s, format_context& ctx) const { return formatter::format( - facebook::velox::cudf_velox::connector::parquet::ParquetDataSink:: + facebook::velox::cudf_velox::connector::hive::CudfHiveDataSink:: stateString(s), ctx); } @@ -378,11 +379,10 @@ struct fmt::formatter< template <> struct fmt::formatter< - facebook::velox::cudf_velox::connector::parquet::LocationHandle::TableType> + facebook::velox::cudf_velox::connector::hive::LocationHandle::TableType> : formatter { auto format( - facebook::velox::cudf_velox::connector::parquet::LocationHandle::TableType - s, + facebook::velox::cudf_velox::connector::hive::LocationHandle::TableType s, format_context& ctx) const { return formatter::format(static_cast(s), ctx); } diff --git a/velox/experimental/cudf/connectors/parquet/ParquetDataSource.cpp b/velox/experimental/cudf/connectors/hive/CudfHiveDataSource.cpp similarity index 75% rename from velox/experimental/cudf/connectors/parquet/ParquetDataSource.cpp rename to velox/experimental/cudf/connectors/hive/CudfHiveDataSource.cpp index 74a0d940fb26..1426c696e74c 100644 --- a/velox/experimental/cudf/connectors/parquet/ParquetDataSource.cpp +++ 
b/velox/experimental/cudf/connectors/hive/CudfHiveDataSource.cpp @@ -14,10 +14,11 @@ * limitations under the License. */ -#include "velox/experimental/cudf/connectors/parquet/ParquetConfig.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetDataSource.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConfig.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveDataSource.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.h" +#include "velox/experimental/cudf/exec/ExpressionEvaluator.h" #include "velox/experimental/cudf/exec/ToCudf.h" #include "velox/experimental/cudf/exec/Utilities.h" #include "velox/experimental/cudf/exec/VeloxCudfInterop.h" @@ -25,6 +26,7 @@ #include "velox/common/time/Timer.h" #include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/TableHandle.h" #include "velox/expression/FieldReference.h" #include @@ -40,19 +42,20 @@ #include #include -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { using namespace facebook::velox::connector; +using namespace facebook::velox::connector::hive; -ParquetDataSource::ParquetDataSource( +CudfHiveDataSource::CudfHiveDataSource( const RowTypePtr& outputType, const ConnectorTableHandlePtr& tableHandle, const ColumnHandleMap& columnHandles, folly::Executor* executor, const ConnectorQueryCtx* connectorQueryCtx, - const std::shared_ptr& parquetConfig) + const std::shared_ptr& parquetConfig) : NvtxHelper( - nvtx3::rgb{80, 171, 241}, // Parquet blue, + nvtx3::rgb{80, 171, 241}, // CudfHive blue, std::nullopt, fmt::format("[{}]", tableHandle->name())), parquetConfig_(parquetConfig), @@ -70,30 +73,28 @@ ParquetDataSource::ParquetDataSource( 
"ColumnHandle is missing for output column: {}", outputName); - auto* handle = static_cast(it->second.get()); + auto* handle = static_cast(it->second.get()); readColumnNames_.emplace_back(handle->name()); } - // Dynamic cast tableHandle to ParquetTableHandle tableHandle_ = - std::dynamic_pointer_cast(tableHandle); + std::dynamic_pointer_cast(tableHandle); VELOX_CHECK_NOT_NULL( - tableHandle_, "TableHandle must be an instance of ParquetTableHandle"); + tableHandle_, "TableHandle must be an instance of HiveTableHandle"); // Create empty IOStats for later use ioStats_ = std::make_shared(); - // Create subfield filter - auto subfieldFilter = tableHandle_->subfieldFilterExpr(); - if (subfieldFilter) { - subfieldFilterExprSet_ = expressionEvaluator_->compile(subfieldFilter); + // Copy subfield filters + for (const auto& [k, v] : tableHandle_->subfieldFilters()) { + subfieldFilters_.emplace(k.clone(), v->clone()); // Add fields in the filter to the columns to read if not there - for (const auto& field : subfieldFilterExprSet_->distinctFields()) { + for (const auto& [field, _] : subfieldFilters_) { if (std::find( readColumnNames_.begin(), readColumnNames_.end(), - field->name()) == readColumnNames_.end()) { - readColumnNames_.push_back(field->name()); + field.toString()) == readColumnNames_.end()) { + readColumnNames_.push_back(field.toString()); } } } @@ -136,7 +137,7 @@ ParquetDataSource::ParquetDataSource( } } -std::optional ParquetDataSource::next( +std::optional CudfHiveDataSource::next( uint64_t /*size*/, velox::ContinueFuture& /* future */) { VELOX_NVTX_OPERATOR_FUNC_RANGE(); @@ -168,7 +169,7 @@ std::optional ParquetDataSource::next( // Launch host callback to calculate timing when scan completes cudaLaunchHostFunc( stream_.value(), - &ParquetDataSource::totalScanTimeCalculator, + &CudfHiveDataSource::totalScanTimeCalculator, callbackData); uint64_t filterTimeUs{0}; @@ -215,6 +216,9 @@ std::optional ParquetDataSource::next( cudfTable = 
std::make_unique(std::move(originalColumns)); } + // TODO (dm): Should we only enable table scan if cudf is registered? + // Earlier we could enable cudf table scans without using other cudf operators + // We still can, but I'm wondering if this is the right thing to do auto output = cudfIsRegistered() ? std::make_shared( pool_, outputType_, nRows, std::move(cudfTable), stream_) @@ -233,7 +237,7 @@ std::optional ParquetDataSource::next( return output; } -void ParquetDataSource::totalScanTimeCalculator(void* userData) { +void CudfHiveDataSource::totalScanTimeCalculator(void* userData) { TotalScanTimeCallbackData* data = static_cast(userData); @@ -250,20 +254,30 @@ void ParquetDataSource::totalScanTimeCalculator(void* userData) { delete data; } -void ParquetDataSource::addSplit(std::shared_ptr split) { +void CudfHiveDataSource::addSplit(std::shared_ptr split) { split_ = [&]() { - // Dynamic cast split to `ParquetConnectorSplit` - if (std::dynamic_pointer_cast(split)) { - return std::dynamic_pointer_cast(split); - // Convert `HiveConnectorSplit` to `ParquetConnectorSplit` + // Dynamic cast split to `CudfHiveConnectorSplit` + if (std::dynamic_pointer_cast(split)) { + return std::dynamic_pointer_cast(split); + // Convert `HiveConnectorSplit` to `CudfHiveConnectorSplit` } else if (std::dynamic_pointer_cast(split)) { const auto hiveSplit = std::dynamic_pointer_cast(split); VELOX_CHECK_EQ( hiveSplit->fileFormat, dwio::common::FileFormat::PARQUET, - "Unsupported file format for conversion from HiveConnectorSplit to cuDF ParquetConnectorSplit"); - return ParquetConnectorSplitBuilder(hiveSplit->filePath) + "Unsupported file format for conversion from HiveConnectorSplit to CudfHiveConnectorSplit"); + VELOX_CHECK_EQ( + hiveSplit->start, + 0, + "CudfHiveDataSource cannot process splits with non-zero offset"); + // Remove "file:" prefix from the file path if present + std::string cleanedPath = hiveSplit->filePath; + constexpr std::string_view kFilePrefix = "file:"; + if 
(cleanedPath.compare(0, kFilePrefix.size(), kFilePrefix) == 0) { + cleanedPath = cleanedPath.substr(kFilePrefix.size()); + } + return CudfHiveConnectorSplitBuilder(cleanedPath) .connectorId(hiveSplit->connectorId) .splitWeight(hiveSplit->splitWeight) .build(); @@ -296,7 +310,7 @@ void ParquetDataSource::addSplit(std::shared_ptr split) { } std::unique_ptr -ParquetDataSource::createSplitReader() { +CudfHiveDataSource::createSplitReader() { // Reader options auto readerOptions = cudf::io::parquet_reader_options::builder(split_->getCudfSourceInfo()) @@ -304,7 +318,7 @@ ParquetDataSource::createSplitReader() { .use_pandas_metadata(parquetConfig_->isUsePandasMetadata()) .use_arrow_schema(parquetConfig_->isUseArrowSchema()) .allow_mismatched_pq_schemas( - parquetConfig_->isAllowMismatchedParquetSchemas()) + parquetConfig_->isAllowMismatchedCudfHiveSchemas()) .timestamp_type(parquetConfig_->timestampType()) .build(); @@ -313,39 +327,30 @@ ParquetDataSource::createSplitReader() { readerOptions.set_num_rows(parquetConfig_->numRows().value()); } - if (subfieldFilterExprSet_) { - auto subfieldFilterExpr = subfieldFilterExprSet_->expr(0); - - // non-ast instructions in filter is not supported for SubFieldFilter. - // precomputeInstructions which are non-ast instructions should be empty. 
- std::vector precomputeInstructions; - - const RowTypePtr readerFilterType_ = [&] { + if (subfieldFilters_.size()) { + const RowTypePtr readerFilterType = [&] { if (tableHandle_->dataColumns()) { - std::vector new_names; - std::vector new_types; + std::vector newNames; + std::vector newTypes; for (const auto& name : readColumnNames_) { // Ensure all columns being read are available to the filter auto parsedType = tableHandle_->dataColumns()->findChild(name); - new_names.emplace_back(std::move(name)); - new_types.push_back(parsedType); + newNames.emplace_back(std::move(name)); + newTypes.push_back(parsedType); } - return ROW(std::move(new_names), std::move(new_types)); + return ROW(std::move(newNames), std::move(newTypes)); } else { return outputType_; } }(); - createAstTree( - subfieldFilterExpr, - subfieldTree_, - subfieldScalars_, - readerFilterType_, - precomputeInstructions); - VELOX_CHECK_EQ(precomputeInstructions.size(), 0); - readerOptions.set_filter(subfieldTree_.back()); + // Build a combined AST for all subfield filters. 
+ auto const& combinedExpr = createAstFromSubfieldFilters( + subfieldFilters_, subfieldTree_, subfieldScalars_, readerFilterType); + + readerOptions.set_filter(combinedExpr); } // Set column projection if needed @@ -363,24 +368,24 @@ ParquetDataSource::createSplitReader() { cudf::get_current_device_resource_ref()); } -void ParquetDataSource::resetSplit() { +void CudfHiveDataSource::resetSplit() { split_.reset(); splitReader_.reset(); columnNames_.clear(); } -std::unordered_map -ParquetDataSource::runtimeStats() { - auto res = runtimeStats_.toMap(); +std::unordered_map +CudfHiveDataSource::getRuntimeStats() { + auto res = runtimeStats_.toRuntimeMetricMap(); res.insert({ {"totalScanTime", - RuntimeCounter(ioStats_->totalScanTime(), RuntimeCounter::Unit::kNanos)}, + RuntimeMetric(ioStats_->totalScanTime(), RuntimeCounter::Unit::kNanos)}, {"totalRemainingFilterTime", - RuntimeCounter( + RuntimeMetric( totalRemainingFilterTime_.load(std::memory_order_relaxed), RuntimeCounter::Unit::kNanos)}, }); return res; } -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/parquet/ParquetDataSource.h b/velox/experimental/cudf/connectors/hive/CudfHiveDataSource.h similarity index 79% rename from velox/experimental/cudf/connectors/parquet/ParquetDataSource.h rename to velox/experimental/cudf/connectors/hive/CudfHiveDataSource.h index c62f52120358..df44aa4e4125 100644 --- a/velox/experimental/cudf/connectors/parquet/ParquetDataSource.h +++ b/velox/experimental/cudf/connectors/hive/CudfHiveDataSource.h @@ -16,34 +16,34 @@ #pragma once -#include "velox/experimental/cudf/connectors/parquet/ParquetConfig.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConfig.h" +#include 
"velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.h" #include "velox/experimental/cudf/exec/ExpressionEvaluator.h" #include "velox/experimental/cudf/exec/NvtxHelper.h" #include "velox/common/base/RandomUtil.h" #include "velox/common/io/IoStatistics.h" #include "velox/connectors/Connector.h" +#include "velox/connectors/hive/TableHandle.h" #include "velox/dwio/common/Statistics.h" #include "velox/type/Type.h" #include #include -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { using namespace facebook::velox::connector; -class ParquetDataSource : public DataSource, public NvtxHelper { +class CudfHiveDataSource : public DataSource, public NvtxHelper { public: - ParquetDataSource( + CudfHiveDataSource( const RowTypePtr& outputType, const ConnectorTableHandlePtr& tableHandle, const ColumnHandleMap& columnHandles, folly::Executor* executor, const ConnectorQueryCtx* connectorQueryCtx, - const std::shared_ptr& ParquetConfig); + const std::shared_ptr& CudfHiveConfig); void addSplit(std::shared_ptr split) override; @@ -51,7 +51,8 @@ class ParquetDataSource : public DataSource, public NvtxHelper { column_index_t /*outputChannel*/, const std::shared_ptr& /*filter*/) override { - VELOX_NYI("Dynamic filters not yet implemented by cudf::ParquetConnector."); + VELOX_NYI( + "Dynamic filters not yet implemented by cudf::CudfHiveConnector."); } std::optional next( @@ -62,11 +63,15 @@ class ParquetDataSource : public DataSource, public NvtxHelper { return completedRows_; } + const common::SubfieldFilters* getFilters() const override { + return &subfieldFilters_; + } + uint64_t getCompletedBytes() override { return completedBytes_; } - std::unordered_map runtimeStats() override; + std::unordered_map getRuntimeStats() override; private: // Create a cudf::io::chunked_parquet_reader with the given split. 
@@ -85,22 +90,23 @@ class ParquetDataSource : public DataSource, public NvtxHelper { } RowVectorPtr emptyOutput_; - std::shared_ptr split_; - std::shared_ptr tableHandle_; + std::shared_ptr split_; + std::shared_ptr + tableHandle_; - const std::shared_ptr parquetConfig_; + const std::shared_ptr parquetConfig_; folly::Executor* const executor_; const ConnectorQueryCtx* const connectorQueryCtx_; memory::MemoryPool* const pool_; - // cuDF Parquet reader stuff. + // cuDF CudfHive reader stuff. cudf::io::parquet_reader_options readerOptions_; std::unique_ptr splitReader_; rmm::cuda_stream_view stream_; - // Table column names read from the Parquet file + // Table column names read from the CudfHive file std::vector columnNames_; // Output type from file reader. This is different from outputType_ that it @@ -127,7 +133,7 @@ class ParquetDataSource : public DataSource, public NvtxHelper { // Expression evaluator for subfield filter. std::vector> subfieldScalars_; cudf::ast::tree subfieldTree_; - std::unique_ptr subfieldFilterExprSet_; + common::SubfieldFilters subfieldFilters_; dwio::common::RuntimeStatistics runtimeStats_; std::atomic totalRemainingFilterTime_{0}; @@ -142,4 +148,4 @@ class ParquetDataSource : public DataSource, public NvtxHelper { static void totalScanTimeCalculator(void* userData); }; -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/parquet/ParquetTableHandle.cpp b/velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.cpp similarity index 74% rename from velox/experimental/cudf/connectors/parquet/ParquetTableHandle.cpp rename to velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.cpp index dcb51868516d..64286506cf09 100644 --- a/velox/experimental/cudf/connectors/parquet/ParquetTableHandle.cpp +++ b/velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.cpp @@ -14,25 +14,25 @@ * limitations under the License. 
*/ -#include "velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.h" #include "velox/connectors/Connector.h" #include "velox/type/Type.h" #include -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { using namespace facebook::velox::connector; -std::string ParquetColumnHandle::toString() const { +std::string CudfHiveColumnHandle::toString() const { std::ostringstream out; out << fmt::format( - "ParquetColumnHandle [name: {}, Type: {},", name_, type_->toString()); + "CudfHiveColumnHandle [name: {}, Type: {},", name_, type_->toString()); return out.str(); } -ParquetTableHandle::ParquetTableHandle( +CudfHiveTableHandle::CudfHiveTableHandle( std::string connectorId, const std::string& tableName, bool filterPushdownEnabled, @@ -46,7 +46,7 @@ ParquetTableHandle::ParquetTableHandle( remainingFilter_(remainingFilter), dataColumns_(dataColumns) {} -std::string ParquetTableHandle::toString() const { +std::string CudfHiveTableHandle::toString() const { std::stringstream out; out << "table: " << tableName_; if (dataColumns_) { @@ -55,10 +55,10 @@ std::string ParquetTableHandle::toString() const { return out.str(); } -ConnectorTableHandlePtr ParquetTableHandle::create( +ConnectorTableHandlePtr CudfHiveTableHandle::create( const folly::dynamic& obj, void* context) { - VELOX_NYI("ParquetTableHandle::create() not yet implemented"); + VELOX_NYI("CudfHiveTableHandle::create() not yet implemented"); } -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h b/velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.h similarity index 83% rename from velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h rename to velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.h index 
b8e1fae7b882..8d97db028d49 100644 --- a/velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h +++ b/velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.h @@ -26,19 +26,19 @@ #include #include -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { using namespace facebook::velox::connector; -// Parquet column handle only needs the column name (all columns are generated +// CudfHive column handle only needs the column name (all columns are generated // in the same way). -class ParquetColumnHandle : public ColumnHandle { +class CudfHiveColumnHandle : public ColumnHandle { public: - explicit ParquetColumnHandle( + explicit CudfHiveColumnHandle( const std::string& name, const TypePtr type, const cudf::data_type cudfDataType, - std::vector children = {}) + std::vector children = {}) : name_(name), type_(type), cudfDataType_(cudfDataType), @@ -56,7 +56,7 @@ class ParquetColumnHandle : public ColumnHandle { return cudfDataType_; } - const std::vector& children() const { + const std::vector& children() const { return children_; } @@ -66,12 +66,12 @@ class ParquetColumnHandle : public ColumnHandle { const std::string name_; const TypePtr type_; const cudf::data_type cudfDataType_; - const std::vector children_; + const std::vector children_; }; -class ParquetTableHandle : public ConnectorTableHandle { +class CudfHiveTableHandle : public ConnectorTableHandle { public: - ParquetTableHandle( + CudfHiveTableHandle( std::string connectorId, const std::string& tableName, bool filterPushdownEnabled, @@ -116,4 +116,4 @@ class ParquetTableHandle : public ConnectorTableHandle { const RowTypePtr dataColumns_; }; -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/parquet/WriterOptions.h b/velox/experimental/cudf/connectors/hive/WriterOptions.h similarity index 88% rename from 
velox/experimental/cudf/connectors/parquet/WriterOptions.h rename to velox/experimental/cudf/connectors/hive/WriterOptions.h index 09d9527ddaeb..389ad985e27e 100644 --- a/velox/experimental/cudf/connectors/parquet/WriterOptions.h +++ b/velox/experimental/cudf/connectors/hive/WriterOptions.h @@ -24,26 +24,26 @@ #include -namespace facebook::velox::cudf_velox::connector::parquet { +namespace facebook::velox::cudf_velox::connector::hive { using namespace cudf::io; /** * @brief Struct to 1:1 correspond with cudf::io::chunked_parquet_reader_options - * except sink_info and a few others which are provided to the ParquetDataSink + * except sink_info and a few others which are provided to the CudfHiveDataSink * from elsewhere. */ -struct ParquetWriterOptions +struct CudfHiveWriterOptions : public facebook::velox::dwio::common::WriterOptions { // Specify the level of statistics in the output file statistics_freq statsLevel = statistics_freq::STATISTICS_ROWGROUP; - // Parquet writer can write INT96 or TIMESTAMP_MICROS. Defaults to + // CudfHive writer can write INT96 or TIMESTAMP_MICROS. Defaults to // TIMESTAMPMICROS. If true then overrides any per-column setting in // Metadata. bool writeTimestampsAsInt96 = false; - // Parquet writer can write timestamps as UTC + // CudfHive writer can write timestamps as UTC // Defaults to true because libcudf timestamps are implicitly UTC bool writeTimestampsAsUTC = true; @@ -87,4 +87,4 @@ struct ParquetWriterOptions std::vector sortingColumns; }; -} // namespace facebook::velox::cudf_velox::connector::parquet +} // namespace facebook::velox::cudf_velox::connector::hive diff --git a/velox/experimental/cudf/connectors/parquet/ParquetConnector.cpp b/velox/experimental/cudf/connectors/parquet/ParquetConnector.cpp deleted file mode 100644 index a6b3e52bd606..000000000000 --- a/velox/experimental/cudf/connectors/parquet/ParquetConnector.cpp +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/experimental/cudf/connectors/parquet/ParquetConnector.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetDataSource.h" - -namespace facebook::velox::cudf_velox::connector::parquet { - -using namespace facebook::velox::connector; - -ParquetConnector::ParquetConnector( - const std::string& id, - std::shared_ptr config, - folly::Executor* executor) - : Connector(id, std::move(config)), - parquetConfig_(std::make_shared(connectorConfig())), - executor_(executor) { - LOG(INFO) << "cudf::Parquet connector " << connectorId() << " created."; -} - -std::unique_ptr ParquetConnector::createDataSource( - const RowTypePtr& outputType, - const ConnectorTableHandlePtr& tableHandle, - const ColumnHandleMap& columnHandles, - ConnectorQueryCtx* connectorQueryCtx) { - return std::make_unique( - outputType, - tableHandle, - columnHandles, - executor_, - connectorQueryCtx, - parquetConfig_); -} - -std::unique_ptr ParquetConnector::createDataSink( - RowTypePtr inputType, - ConnectorInsertTableHandlePtr connectorInsertTableHandle, - ConnectorQueryCtx* connectorQueryCtx, - CommitStrategy /*commitStrategy*/) { - auto parquetInsertHandle = - std::dynamic_pointer_cast( - connectorInsertTableHandle); - VELOX_CHECK_NOT_NULL( - parquetInsertHandle, "Parquet connector expecting parquet write handle!"); - return std::make_unique( - inputType, - parquetInsertHandle, - connectorQueryCtx, - 
CommitStrategy::kNoCommit, - parquetConfig_); -} - -std::shared_ptr ParquetConnectorFactory::newConnector( - const std::string& id, - std::shared_ptr config, - folly::Executor* ioExecutor, - folly::Executor* cpuExecutor) { - return std::make_shared(id, config, ioExecutor); -} - -} // namespace facebook::velox::cudf_velox::connector::parquet diff --git a/velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.cpp b/velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.cpp deleted file mode 100644 index 1dc05127659e..000000000000 --- a/velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.cpp +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.h" - -#include - -namespace facebook::velox::cudf_velox::connector::parquet { - -std::string ParquetConnectorSplit::toString() const { - return fmt::format("Parquet: {}", filePath); -} - -std::string ParquetConnectorSplit::getFileName() const { - const auto i = filePath.rfind('/'); - return i == std::string::npos ? 
filePath : filePath.substr(i + 1); -} - -// static -std::shared_ptr ParquetConnectorSplit::create( - const folly::dynamic& obj) { - const auto connectorId = obj["connectorId"].asString(); - const auto splitWeight = obj["splitWeight"].asInt(); - const auto filePath = obj["filePath"].asString(); - - return std::make_shared( - connectorId, filePath, splitWeight); -} - -} // namespace facebook::velox::cudf_velox::connector::parquet diff --git a/velox/experimental/cudf/exec/CMakeLists.txt b/velox/experimental/cudf/exec/CMakeLists.txt index f83c4cf940f9..1a036111026e 100644 --- a/velox/experimental/cudf/exec/CMakeLists.txt +++ b/velox/experimental/cudf/exec/CMakeLists.txt @@ -14,6 +14,7 @@ add_library( velox_cudf_exec + CudfAssignUniqueId.cpp CudfConversion.cpp CudfFilterProject.cpp CudfHashAggregation.cpp @@ -31,7 +32,14 @@ add_library( target_link_libraries( velox_cudf_exec PUBLIC cudf::cudf - PRIVATE arrow velox_arrow_bridge velox_exception velox_common_base velox_cudf_vector velox_exec + PRIVATE + arrow + velox_arrow_bridge + velox_exception + velox_common_base + velox_cudf_vector + velox_exec + velox_cudf_hive_connector ) target_compile_options(velox_cudf_exec PRIVATE -Wno-missing-field-initializers) diff --git a/velox/experimental/cudf/exec/CudfAssignUniqueId.cpp b/velox/experimental/cudf/exec/CudfAssignUniqueId.cpp new file mode 100644 index 000000000000..10b779317ec7 --- /dev/null +++ b/velox/experimental/cudf/exec/CudfAssignUniqueId.cpp @@ -0,0 +1,150 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/experimental/cudf/exec/CudfAssignUniqueId.h" + +#include + +#include + +namespace facebook::velox::cudf_velox { + +CudfAssignUniqueId::CudfAssignUniqueId( + int32_t operatorId, + exec::DriverCtx* driverCtx, + const std::shared_ptr& planNode, + int32_t uniqueTaskId, + std::shared_ptr rowIdPool) + : Operator( + driverCtx, + planNode->outputType(), + operatorId, + planNode->id(), + "CudfAssignUniqueId"), + NvtxHelper( + nvtx3::rgb{160, 82, 45}, // Sienna + operatorId, + fmt::format("[{}]", planNode->id())), + rowIdPool_(std::move(rowIdPool)) { + VELOX_USER_CHECK_LT( + uniqueTaskId, + kTaskUniqueIdLimit, + "Unique 24-bit ID specified for CudfAssignUniqueId exceeds the limit"); + uniqueValueMask_ = static_cast(uniqueTaskId) << 40; + + rowIdCounter_ = 0; + maxRowIdCounterValue_ = 0; +} + +void CudfAssignUniqueId::addInput(RowVectorPtr input) { + VELOX_NVTX_OPERATOR_FUNC_RANGE(); + auto numInput = input->size(); + VELOX_CHECK_NE( + numInput, 0, "CudfAssignUniqueId::addInput received empty set of rows"); + input_ = std::move(input); +} + +RowVectorPtr CudfAssignUniqueId::getOutput() { + VELOX_NVTX_OPERATOR_FUNC_RANGE(); + + if (input_ == nullptr) { + return nullptr; + } + + auto cudfVector = std::dynamic_pointer_cast(input_); + VELOX_CHECK(cudfVector, "Input must be a CudfVector"); + auto stream = cudfVector->stream(); + auto uniqueIdColumn = generateIdColumn( + input_->size(), stream, cudf::get_current_device_resource_ref()); + auto size = cudfVector->size(); + auto columns = cudfVector->release()->release(); + columns.push_back(std::move(uniqueIdColumn)); + auto output = std::make_shared( + input_->pool(), + outputType_, + size, + std::make_unique(std::move(columns)), + stream); + input_ = nullptr; + return output; +} + +bool CudfAssignUniqueId::isFinished() { + return noMoreInput_ && input_ == nullptr; +} + +std::unique_ptr 
CudfAssignUniqueId::generateIdColumn( + vector_size_t size, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) { + std::vector starts, sizes; + starts.reserve(size / kRowIdsPerRequest + 1); + sizes.reserve(size / kRowIdsPerRequest + 1); + + vector_size_t start = 0; + while (start < size) { + if (rowIdCounter_ >= maxRowIdCounterValue_) { + requestRowIds(); + } + + const auto numAvailableIds = + std::min(maxRowIdCounterValue_ - rowIdCounter_, kRowIdsPerRequest); + const vector_size_t end = + std::min(static_cast(size), start + numAvailableIds); + VELOX_CHECK_EQ( + (rowIdCounter_ + (end - start)) & uniqueValueMask_, + 0, + "Ran out of unique IDs at {}. Need {} more.", + rowIdCounter_, + (end - start)); + starts.push_back(uniqueValueMask_ | rowIdCounter_); + sizes.push_back(end - start); + + rowIdCounter_ += (end - start); + start = end; + } + + // Copy starts and sizes to device. + rmm::device_buffer d_starts_buffer( + starts.data(), starts.size() * sizeof(int64_t), stream, mr); + rmm::device_buffer d_sizes_buffer( + sizes.data(), sizes.size() * sizeof(int64_t), stream, mr); + auto d_starts_column_view = cudf::column_view( + cudf::data_type(cudf::type_id::INT64), + starts.size(), + d_starts_buffer.data(), + nullptr, + 0, + 0); + auto d_sizes_column_view = cudf::column_view( + cudf::data_type(cudf::type_id::INT64), + sizes.size(), + d_sizes_buffer.data(), + nullptr, + 0, + 0); + + auto list_sequence = cudf::lists::sequences( + d_starts_column_view, d_sizes_column_view, stream, mr); + // Discard offsets. 
+ return std::move(list_sequence->release().children[1]); +} + +void CudfAssignUniqueId::requestRowIds() { + rowIdCounter_ = rowIdPool_->fetch_add(kRowIdsPerRequest); + maxRowIdCounterValue_ = + std::min(rowIdCounter_ + kRowIdsPerRequest, kMaxRowId); +} +} // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/exec/CudfAssignUniqueId.h b/velox/experimental/cudf/exec/CudfAssignUniqueId.h new file mode 100644 index 000000000000..b76870b563eb --- /dev/null +++ b/velox/experimental/cudf/exec/CudfAssignUniqueId.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "velox/experimental/cudf/exec/NvtxHelper.h" +#include "velox/experimental/cudf/vector/CudfVector.h" + +#include "velox/exec/Operator.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::cudf_velox { + +class CudfAssignUniqueId : public exec::Operator, public NvtxHelper { + public: + CudfAssignUniqueId( + int32_t operatorId, + exec::DriverCtx* driverCtx, + const std::shared_ptr& planNode, + int32_t uniqueTaskId, + std::shared_ptr rowIdPool); + + bool isFilter() const override { + return true; + } + + bool preservesOrder() const override { + return true; + } + + bool needsInput() const override { + return input_ == nullptr; + } + + void addInput(RowVectorPtr input) override; + + RowVectorPtr getOutput() override; + + exec::BlockingReason isBlocked(ContinueFuture* /*future*/) override { + return exec::BlockingReason::kNotBlocked; + } + + bool startDrain() override { + // No need to drain for assignUniqueId operator. + return false; + } + + bool isFinished() override; + + private: + std::unique_ptr generateIdColumn( + vector_size_t size, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + + void requestRowIds(); + + const int64_t kRowIdsPerRequest = 1L << 20; + const int64_t kMaxRowId = 1L << 40; + const int64_t kTaskUniqueIdLimit = 1L << 24; + + int64_t uniqueValueMask_; + int64_t rowIdCounter_; + int64_t maxRowIdCounterValue_; + + std::shared_ptr rowIdPool_; +}; +} // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/exec/CudfConversion.cpp b/velox/experimental/cudf/exec/CudfConversion.cpp index 09c8f4774fd9..86f192e3d2b1 100644 --- a/velox/experimental/cudf/exec/CudfConversion.cpp +++ b/velox/experimental/cudf/exec/CudfConversion.cpp @@ -278,12 +278,7 @@ RowVectorPtr CudfToVelox::getOutput() { } // Concatenate the selected tables on the GPU - std::unique_ptr resultTable; - if (selectedInputs.size() == 1) { - resultTable = selectedInputs[0]->release(); - } else { - 
resultTable = getConcatenatedTable(selectedInputs, stream); - } + auto resultTable = getConcatenatedTable(selectedInputs, outputType_, stream); // Convert the concatenated table to a RowVector const auto size = resultTable->num_rows(); diff --git a/velox/experimental/cudf/exec/CudfFilterProject.cpp b/velox/experimental/cudf/exec/CudfFilterProject.cpp index c4c3785144cc..dd1989129803 100644 --- a/velox/experimental/cudf/exec/CudfFilterProject.cpp +++ b/velox/experimental/cudf/exec/CudfFilterProject.cpp @@ -14,12 +14,14 @@ * limitations under the License. */ +#include "velox/experimental/cudf/CudfConfig.h" #include "velox/experimental/cudf/exec/CudfFilterProject.h" #include "velox/experimental/cudf/exec/ToCudf.h" #include "velox/experimental/cudf/exec/VeloxCudfInterop.h" #include "velox/experimental/cudf/vector/CudfVector.h" #include "velox/expression/Expr.h" +#include "velox/expression/FieldReference.h" #include #include @@ -40,13 +42,64 @@ void debugPrintTree( debugPrintTree(input, indent + 2); } } + +bool checkAddIdentityProjection( + const core::TypedExprPtr& projection, + const RowTypePtr& inputType, + column_index_t outputChannel, + std::vector& identityProjections) { + if (auto field = core::TypedExprs::asFieldAccess(projection)) { + const auto& inputs = field->inputs(); + if (inputs.empty() || + (inputs.size() == 1 && + dynamic_cast(inputs[0].get()))) { + const auto inputChannel = inputType->getChildIdx(field->name()); + identityProjections.emplace_back(inputChannel, outputChannel); + return true; + } + } + + return false; +} + +// Split stats to attrbitute cardinality reduction to the Filter node. 
+std::vector splitStats( + const exec::OperatorStats& combinedStats, + const core::PlanNodeId& filterNodeId) { + exec::OperatorStats filterStats; + + filterStats.operatorId = combinedStats.operatorId; + filterStats.pipelineId = combinedStats.pipelineId; + filterStats.planNodeId = filterNodeId; + filterStats.operatorType = combinedStats.operatorType; + filterStats.numDrivers = combinedStats.numDrivers; + + filterStats.inputBytes = combinedStats.inputBytes; + filterStats.inputPositions = combinedStats.inputPositions; + filterStats.inputVectors = combinedStats.inputVectors; + + // Estimate Filter's output bytes based on cardinality change. + const double filterRate = combinedStats.inputPositions > 0 + ? (combinedStats.outputPositions * 1.0 / combinedStats.inputPositions) + : 1.0; + + filterStats.outputBytes = (uint64_t)(filterStats.inputBytes * filterRate); + filterStats.outputPositions = combinedStats.outputPositions; + filterStats.outputVectors = combinedStats.outputVectors; + + auto projectStats = combinedStats; + projectStats.inputBytes = filterStats.outputBytes; + projectStats.inputPositions = filterStats.outputPositions; + projectStats.inputVectors = filterStats.outputVectors; + + return {std::move(projectStats), std::move(filterStats)}; +} + } // namespace CudfFilterProject::CudfFilterProject( int32_t operatorId, velox::exec::DriverCtx* driverCtx, - const velox::exec::FilterProject::Export& info, - std::vector identityProjections, const std::shared_ptr& filter, const std::shared_ptr& project) : Operator( @@ -59,30 +112,78 @@ CudfFilterProject::CudfFilterProject( nvtx3::rgb{220, 20, 60}, // Crimson operatorId, fmt::format("[{}]", project ? 
project->id() : filter->id())), - hasFilter_(info.hasFilter), + hasFilter_(filter != nullptr), project_(project), filter_(filter) { - resultProjections_ = *(info.resultProjections); - identityProjections_ = std::move(identityProjections); + if (filter_ != nullptr && project_ != nullptr) { + folly::Synchronized& opStats = Operator::stats(); + opStats.withWLock([&](auto& stats) { + stats.setStatSplitter( + [filterId = filter_->id()](const auto& combinedStats) { + return splitStats(combinedStats, filterId); + }); + }); + } +} + +void CudfFilterProject::initialize() { + Operator::initialize(); + + std::vector allExprs; + if (hasFilter_) { + VELOX_CHECK_NOT_NULL(filter_); + allExprs.push_back(filter_->filter()); + } + + if (project_) { + const auto& inputType = project_->sources()[0]->outputType(); + + for (column_index_t i = 0; i < project_->projections().size(); i++) { + auto& projection = project_->projections()[i]; + bool identityProjection = checkAddIdentityProjection( + projection, inputType, i, identityProjections_); + if (!identityProjection) { + allExprs.push_back(projection); + resultProjections_.emplace_back(allExprs.size() - 1, i); + } + } + } else { + for (column_index_t i = 0; i < outputType_->size(); ++i) { + identityProjections_.emplace_back(i, i); + } + isIdentityProjection_ = true; + } + + auto lazyDereference = + (dynamic_cast(project_.get()) != + nullptr); + VELOX_CHECK(!(lazyDereference && filter_)); + auto expr = exec::makeExprSetFromFlag( + std::move(allExprs), operatorCtx_->execCtx(), lazyDereference); + const auto inputType = project_ ? 
project_->sources()[0]->outputType() : filter_->sources()[0]->outputType(); // convert to AST - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { int i = 0; - for (auto expr : info.exprs->exprs()) { + for (const auto& expr : expr->exprs()) { std::cout << "expr[" << i++ << "] " << expr->toString() << std::endl; debugPrintTree(expr); + ++i; } } std::vector> projectExprs; if (hasFilter_) { - // First expr is Filter, rest are Project - filterEvaluator_ = ExpressionEvaluator({info.exprs->exprs()[0]}, inputType); - projectExprs = {info.exprs->exprs().begin() + 1, info.exprs->exprs().end()}; + filterEvaluator_ = ExpressionEvaluator({expr->exprs()[0]}, inputType); + projectExprs = {expr->exprs().begin() + 1, expr->exprs().end()}; } - projectEvaluator_ = ExpressionEvaluator( - hasFilter_ ? projectExprs : info.exprs->exprs(), inputType); + + projectEvaluator_ = + ExpressionEvaluator(hasFilter_ ? projectExprs : expr->exprs(), inputType); + + filter_.reset(); + project_.reset(); } void CudfFilterProject::addInput(RowVectorPtr input) { @@ -114,7 +215,7 @@ RowVectorPtr CudfFilterProject::getOutput() { stream.synchronize(); auto const numColumns = outputTable->num_columns(); auto const size = outputTable->num_rows(); - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "cudfProject Output: " << size << " rows, " << numColumns << " columns " << std::endl; } diff --git a/velox/experimental/cudf/exec/CudfFilterProject.h b/velox/experimental/cudf/exec/CudfFilterProject.h index 23853ddb47e4..1c63c5b57c97 100644 --- a/velox/experimental/cudf/exec/CudfFilterProject.h +++ b/velox/experimental/cudf/exec/CudfFilterProject.h @@ -33,11 +33,12 @@ class CudfFilterProject : public exec::Operator, public NvtxHelper { CudfFilterProject( int32_t operatorId, velox::exec::DriverCtx* driverCtx, - const velox::exec::FilterProject::Export& info, - std::vector identityProjections, const std::shared_ptr& filter, const std::shared_ptr& project); 
+ // Some is copied from operator FilterProject. + void initialize() override; + bool needsInput() const override { return !input_; } @@ -68,12 +69,15 @@ class CudfFilterProject : public exec::Operator, public NvtxHelper { private: bool allInputProcessed(); + // If true exprs_[0] is a filter and the other expressions are projections const bool hasFilter_{false}; + // Cached filter and project node for lazy initialization. After // initialization, they will be reset, and initialized_ will be set to true. std::shared_ptr project_; std::shared_ptr filter_; + ExpressionEvaluator projectEvaluator_; ExpressionEvaluator filterEvaluator_; diff --git a/velox/experimental/cudf/exec/CudfHashAggregation.cpp b/velox/experimental/cudf/exec/CudfHashAggregation.cpp index cd7b1f1befba..60c6fc412ab9 100644 --- a/velox/experimental/cudf/exec/CudfHashAggregation.cpp +++ b/velox/experimental/cudf/exec/CudfHashAggregation.cpp @@ -14,8 +14,8 @@ * limitations under the License. */ +#include "velox/experimental/cudf/CudfConfig.h" #include "velox/experimental/cudf/exec/CudfHashAggregation.h" -#include "velox/experimental/cudf/exec/ToCudf.h" #include "velox/experimental/cudf/exec/Utilities.h" #include "velox/experimental/cudf/exec/VeloxCudfInterop.h" @@ -426,7 +426,7 @@ std::unique_ptr createAggregator( VectorPtr constant, bool isGlobal, const TypePtr& resultType) { - auto prefix = cudf_velox::CudfOptions::getInstance().prefix(); + auto prefix = cudf_velox::CudfConfig::getInstance().functionNamePrefix; if (kind.rfind(prefix + "sum", 0) == 0) { return std::make_unique( step, inputIndex, constant, isGlobal, resultType); @@ -465,7 +465,7 @@ core::AggregationNode::Step getCompanionStep( std::string const& kind, core::AggregationNode::Step step) { for (const auto& [k, v] : companionStep) { - if (folly::StringPiece(kind).endsWith(k)) { + if (kind.ends_with(k)) { step = v; break; } @@ -475,7 +475,7 @@ core::AggregationNode::Step getCompanionStep( std::string getOriginalName(std::string const& kind) 
{ for (const auto& [k, v] : companionStep) { - if (folly::StringPiece(kind).endsWith(k)) { + if (kind.ends_with(k)) { return kind.substr(0, kind.length() - k.length()); } } @@ -485,7 +485,7 @@ std::string getOriginalName(std::string const& kind) { bool hasFinalAggs( std::vector const& aggregates) { return std::any_of(aggregates.begin(), aggregates.end(), [](auto const& agg) { - return folly::StringPiece(agg.call->name()).endsWith("_merge_extract"); + return agg.call->name().ends_with("_merge_extract"); }); } @@ -579,16 +579,6 @@ auto toIntermediateAggregators( return aggregators; } -std::unique_ptr makeEmptyTable(TypePtr const& inputType) { - std::vector> emptyColumns; - for (size_t i = 0; i < inputType->size(); ++i) { - auto emptyColumn = cudf::make_empty_column( - cudf_velox::veloxToCudfTypeId(inputType->childAt(i))); - emptyColumns.push_back(std::move(emptyColumn)); - } - return std::make_unique(std::move(emptyColumns)); -} - } // namespace namespace facebook::velox::cudf_velox { @@ -938,8 +928,7 @@ RowVectorPtr CudfHashAggregation::getOutput() { auto stream = cudfGlobalStreamPool().get_stream(); - auto tbl = inputs_.empty() ? makeEmptyTable(inputType_) - : getConcatenatedTable(inputs_, stream); + auto tbl = getConcatenatedTable(inputs_, inputType_, stream); // Release input data after synchronizing. stream.synchronize(); diff --git a/velox/experimental/cudf/exec/CudfHashJoin.cpp b/velox/experimental/cudf/exec/CudfHashJoin.cpp index 8e52a8952a14..19b0b2f4f7ad 100644 --- a/velox/experimental/cudf/exec/CudfHashJoin.cpp +++ b/velox/experimental/cudf/exec/CudfHashJoin.cpp @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include "velox/experimental/cudf/CudfConfig.h" #include "velox/experimental/cudf/exec/CudfHashJoin.h" #include "velox/experimental/cudf/exec/ExpressionEvaluator.h" #include "velox/experimental/cudf/exec/ToCudf.h" @@ -23,6 +24,8 @@ #include "velox/exec/Task.h" #include +#include +#include #include #include @@ -32,7 +35,7 @@ namespace facebook::velox::cudf_velox { void CudfHashJoinBridge::setHashTable( std::optional hashObject) { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinBridge::setHashTable" << std::endl; } std::vector promises; @@ -49,24 +52,24 @@ void CudfHashJoinBridge::setHashTable( std::optional CudfHashJoinBridge::hashOrFuture( ContinueFuture* future) { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinBridge::hashOrFuture" << std::endl; } std::lock_guard l(mutex_); if (hashObject_.has_value()) { return hashObject_; } - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinBridge::hashOrFuture constructing promise" << std::endl; } promises_.emplace_back("CudfHashJoinBridge::hashOrFuture"); - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinBridge::hashOrFuture getSemiFuture" << std::endl; } *future = promises_.back().getSemiFuture(); - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinBridge::hashOrFuture returning nullopt" << std::endl; } @@ -89,13 +92,13 @@ CudfHashJoinBuild::CudfHashJoinBuild( operatorId, fmt::format("[{}]", joinNode->id())), joinNode_(joinNode) { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "CudfHashJoinBuild constructor" << std::endl; } } void CudfHashJoinBuild::addInput(RowVectorPtr input) { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling 
CudfHashJoinBuild::addInput" << std::endl; } // Queue inputs, process all at once. @@ -117,7 +120,7 @@ void CudfHashJoinBuild::addInput(RowVectorPtr input) { } bool CudfHashJoinBuild::needsInput() const { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinBuild::needsInput" << std::endl; } return !noMoreInput_; @@ -128,7 +131,7 @@ RowVectorPtr CudfHashJoinBuild::getOutput() { } void CudfHashJoinBuild::noMoreInput() { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinBuild::noMoreInput" << std::endl; } VELOX_NVTX_OPERATOR_FUNC_RANGE(); @@ -158,22 +161,15 @@ void CudfHashJoinBuild::noMoreInput() { }; auto stream = cudfGlobalStreamPool().get_stream(); - std::unique_ptr tbl; - if (inputs_.size() == 0) { - auto emptyRowVector = RowVector::createEmpty( - joinNode_->sources()[1]->outputType(), operatorCtx_->pool()); - tbl = facebook::velox::cudf_velox::with_arrow::toCudfTable( - emptyRowVector, operatorCtx_->pool(), stream); - } else { - tbl = getConcatenatedTable(inputs_, stream); - } + auto tbl = getConcatenatedTable( + inputs_, joinNode_->sources()[1]->outputType(), stream); // Release input data after synchronizing stream.synchronize(); inputs_.clear(); VELOX_CHECK_NOT_NULL(tbl); - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Build table number of columns: " << tbl->num_columns() << std::endl; std::cout << "Build table number of rows: " << tbl->num_rows() << std::endl; @@ -202,7 +198,7 @@ void CudfHashJoinBuild::noMoreInput() { VELOX_CHECK_NOT_NULL(hashObject); } - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { if (hashObject != nullptr) { printf("hashObject is not nullptr %p\n", hashObject.get()); } else { @@ -246,7 +242,7 @@ CudfHashJoinProbe::CudfHashJoinProbe( operatorId, fmt::format("[{}]", joinNode->id())), joinNode_(joinNode) { - if (cudfDebugEnabled()) { + if 
(CudfConfig::getInstance().debugEnabled) { std::cout << "CudfHashJoinProbe constructor" << std::endl; } auto probeType = joinNode_->sources()[0]->outputType(); @@ -254,7 +250,7 @@ CudfHashJoinProbe::CudfHashJoinProbe( auto const& leftKeys = joinNode_->leftKeys(); // probe keys auto const& rightKeys = joinNode_->rightKeys(); // build keys - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { for (int i = 0; i < probeType->names().size(); i++) { std::cout << "Left column " << i << ": " << probeType->names()[i] << std::endl; @@ -298,7 +294,7 @@ CudfHashJoinProbe::CudfHashJoinProbe( rightColumnOutputIndices_ = std::vector(); for (int i = 0; i < outputType->names().size(); i++) { auto const outputName = outputType->names()[i]; - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Output column " << i << ": " << outputName << std::endl; } auto channel = probeType->getChildIdxIfExists(outputName); @@ -319,7 +315,7 @@ CudfHashJoinProbe::CudfHashJoinProbe( "Join field {} not in probe or build input", outputType->children()[i]); } - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { for (int i = 0; i < leftColumnIndicesToGather_.size(); i++) { std::cout << "Left index to gather " << i << ": " << leftColumnIndicesToGather_[i] << std::endl; @@ -373,7 +369,13 @@ CudfHashJoinProbe::CudfHashJoinProbe( } bool CudfHashJoinProbe::needsInput() const { - return !finished_ && input_ == nullptr; + if (CudfConfig::getInstance().debugEnabled) { + std::cout << "Calling CudfHashJoinProbe::needsInput" << std::endl; + } + if (joinNode_->isRightJoin() || joinNode_->isRightSemiFilterJoin()) { + return !noMoreInput_; + } + return !noMoreInput_ && !finished_ && input_ == nullptr; } void CudfHashJoinProbe::addInput(RowVectorPtr input) { @@ -393,11 +395,78 @@ void CudfHashJoinProbe::addInput(RowVectorPtr input) { auto lockedStats = stats_.wlock(); lockedStats->numNullKeys += null_count; } - input_ = std::move(input); + 
if (!joinNode_->isRightJoin() && !joinNode_->isRightSemiFilterJoin()) { + input_ = std::move(input); + return; + } + + // Queue inputs and process all at once + if (input->size() > 0) { + inputs_.push_back(std::move(cudfInput)); + } +} + +void CudfHashJoinProbe::noMoreInput() { + if (CudfConfig::getInstance().debugEnabled) { + std::cout << "Calling CudfHashJoinProbe::noMoreInput" << std::endl; + } + VELOX_NVTX_OPERATOR_FUNC_RANGE(); + Operator::noMoreInput(); + if (!joinNode_->isRightJoin() && !joinNode_->isRightSemiFilterJoin()) { + return; + } + std::vector promises; + std::vector> peers; + // Only last driver collects all answers + if (!operatorCtx_->task()->allPeersFinished( + planNodeId(), operatorCtx_->driver(), &future_, promises, peers)) { + return; + } + // Collect results from peers + for (auto& peer : peers) { + auto op = peer->findOperator(planNodeId()); + auto* probe = dynamic_cast(op); + VELOX_CHECK_NOT_NULL(probe); + inputs_.insert(inputs_.end(), probe->inputs_.begin(), probe->inputs_.end()); + } + + SCOPE_EXIT { + // Realize the promises so that the other Drivers (which were not + // the last to finish) can continue from the barrier and finish. 
+ peers.clear(); + for (auto& promise : promises) { + promise.setValue(); + } + }; + + auto stream = cudfGlobalStreamPool().get_stream(); + auto tbl = getConcatenatedTable( + inputs_, joinNode_->sources()[1]->outputType(), stream); + + // Release input data after synchronizing + stream.synchronize(); + + VELOX_CHECK_NOT_NULL(tbl); + + if (CudfConfig::getInstance().debugEnabled) { + std::cout << "Probe table number of columns: " << tbl->num_columns() + << std::endl; + std::cout << "Probe table number of rows: " << tbl->num_rows() << std::endl; + } + + // Store the concatenated table in input_ + input_ = std::make_shared( + operatorCtx_->pool(), + joinNode_->outputType(), + tbl->num_rows(), + std::move(tbl), + stream); + + inputs_.clear(); } RowVectorPtr CudfHashJoinProbe::getOutput() { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinProbe::getOutput" << std::endl; } VELOX_NVTX_OPERATOR_FUNC_RANGE(); @@ -413,7 +482,7 @@ RowVectorPtr CudfHashJoinProbe::getOutput() { VELOX_CHECK_NOT_NULL(cudfInput); auto stream = cudfInput->stream(); auto leftTable = cudfInput->release(); // probe table - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Probe table number of columns: " << leftTable->num_columns() << std::endl; std::cout << "Probe table number of rows: " << leftTable->num_rows() @@ -426,7 +495,7 @@ RowVectorPtr CudfHashJoinProbe::getOutput() { auto& rightTable = hashObject_.value().first; auto& hb = hashObject_.value().second; VELOX_CHECK_NOT_NULL(rightTable); - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { if (rightTable != nullptr) printf( "right_table is not nullptr %p hasValue(%d)\n", @@ -601,7 +670,7 @@ RowVectorPtr CudfHashJoinProbe::getOutput() { auto rightResult = cudf::gather(rightInput, rightIndicesCol, oobPolicy, stream); - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Left result number of 
columns: " << leftResult->num_columns() << std::endl; std::cout << "Right result number of columns: " @@ -640,6 +709,15 @@ bool CudfHashJoinProbe::skipProbeOnEmptyBuild() const { } exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) { + if ((joinNode_->isRightJoin() || joinNode_->isRightSemiFilterJoin()) && + hashObject_.has_value()) { + if (!future_.valid()) { + return exec::BlockingReason::kNotBlocked; + } + *future = std::move(future_); + return exec::BlockingReason::kWaitForJoinProbe; + } + if (hashObject_.has_value()) { return exec::BlockingReason::kNotBlocked; } @@ -653,7 +731,7 @@ exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) { auto hashObject = cudfJoinBridge->hashOrFuture(future); if (!hashObject.has_value()) { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "CudfHashJoinProbe is blocked, waiting for join build" << std::endl; } @@ -675,6 +753,11 @@ exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) { } } } + if ((joinNode_->isRightJoin() || joinNode_->isRightSemiFilterJoin()) && + future_.valid()) { + *future = std::move(future_); + return exec::BlockingReason::kWaitForJoinProbe; + } return exec::BlockingReason::kNotBlocked; } @@ -692,7 +775,7 @@ std::unique_ptr CudfHashJoinBridgeTranslator::toOperator( exec::DriverCtx* ctx, int32_t id, const core::PlanNodePtr& node) { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinBridgeTranslator::toOperator" << std::endl; } @@ -705,7 +788,7 @@ std::unique_ptr CudfHashJoinBridgeTranslator::toOperator( std::unique_ptr CudfHashJoinBridgeTranslator::toJoinBridge( const core::PlanNodePtr& node) { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinBridgeTranslator::toJoinBridge" << std::endl; } @@ -719,7 +802,7 @@ std::unique_ptr CudfHashJoinBridgeTranslator::toJoinBridge( 
exec::OperatorSupplier CudfHashJoinBridgeTranslator::toOperatorSupplier( const core::PlanNodePtr& node) { - if (cudfDebugEnabled()) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Calling CudfHashJoinBridgeTranslator::toOperatorSupplier" << std::endl; } diff --git a/velox/experimental/cudf/exec/CudfHashJoin.h b/velox/experimental/cudf/exec/CudfHashJoin.h index 00a9b9878545..3e11a307ba7a 100644 --- a/velox/experimental/cudf/exec/CudfHashJoin.h +++ b/velox/experimental/cudf/exec/CudfHashJoin.h @@ -27,8 +27,6 @@ #include #include -#include -#include #include namespace facebook::velox::cudf_velox { @@ -84,6 +82,8 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper { void addInput(RowVectorPtr input) override; + void noMoreInput() override; + RowVectorPtr getOutput() override; bool skipProbeOnEmptyBuild() const; @@ -94,7 +94,9 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper { return joinType == core::JoinType::kInner || joinType == core::JoinType::kLeft || joinType == core::JoinType::kAnti || - joinType == core::JoinType::kLeftSemiFilter; + joinType == core::JoinType::kLeftSemiFilter || + joinType == core::JoinType::kRight || + joinType == core::JoinType::kRightSemiFilter; } bool isFinished() override; @@ -109,6 +111,10 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper { bool rightPrecomputed_{false}; + // Batched probe inputs needed for right join + std::vector inputs_; + ContinueFuture future_{ContinueFuture::makeEmpty()}; + std::vector leftKeyIndices_; std::vector rightKeyIndices_; std::vector leftColumnIndicesToGather_; diff --git a/velox/experimental/cudf/exec/CudfLocalPartition.cpp b/velox/experimental/cudf/exec/CudfLocalPartition.cpp index 3a9e87cf2f62..72ff34e652fe 100644 --- a/velox/experimental/cudf/exec/CudfLocalPartition.cpp +++ b/velox/experimental/cudf/exec/CudfLocalPartition.cpp @@ -17,6 +17,7 @@ #include "velox/experimental/cudf/exec/CudfLocalPartition.h" #include 
"velox/experimental/cudf/vector/CudfVector.h" +#include "velox/exec/HashPartitionFunction.h" #include "velox/exec/Task.h" #include @@ -24,6 +25,15 @@ namespace facebook::velox::cudf_velox { +bool CudfLocalPartition::shouldReplace( + const std::shared_ptr& planNode) { + auto* hashFunctionSpec = dynamic_cast( + &planNode->partitionFunctionSpec()); + // Only replace LocalPartition with CudfLocalPartition for hash partitioning. + // TODO: Round Robin Row-Wise Partitioning can be supported in future. + return hashFunctionSpec; +} + CudfLocalPartition::CudfLocalPartition( int32_t operatorId, exec::DriverCtx* ctx, @@ -55,9 +65,11 @@ CudfLocalPartition::CudfLocalPartition( // Get partition function specification string std::string spec = planNode->partitionFunctionSpec().toString(); + auto* hashFunctionSpec = dynamic_cast( + &planNode->partitionFunctionSpec()); // Only parse keys if it's a hash function - if (spec.find("HASH(") != std::string::npos) { + if (hashFunctionSpec) { // Extract keys between HASH( and ) size_t start = spec.find("HASH(") + 5; size_t end = spec.find(")", start); diff --git a/velox/experimental/cudf/exec/CudfLocalPartition.h b/velox/experimental/cudf/exec/CudfLocalPartition.h index 709aa92a470d..deb302621953 100644 --- a/velox/experimental/cudf/exec/CudfLocalPartition.h +++ b/velox/experimental/cudf/exec/CudfLocalPartition.h @@ -51,6 +51,9 @@ class CudfLocalPartition : public exec::Operator, public NvtxHelper { bool isFinished() override; + static bool shouldReplace( + const std::shared_ptr& planNode); + protected: const std::vector> queues_; const size_t numPartitions_; diff --git a/velox/experimental/cudf/exec/CudfOrderBy.cpp b/velox/experimental/cudf/exec/CudfOrderBy.cpp index d395ac9d0a9a..19935c3ab3c5 100644 --- a/velox/experimental/cudf/exec/CudfOrderBy.cpp +++ b/velox/experimental/cudf/exec/CudfOrderBy.cpp @@ -77,7 +77,7 @@ void CudfOrderBy::noMoreInput() { } auto stream = cudfGlobalStreamPool().get_stream(); - auto tbl = 
getConcatenatedTable(inputs_, stream); + auto tbl = getConcatenatedTable(inputs_, outputType_, stream); // Release input data after synchronizing stream.synchronize(); diff --git a/velox/experimental/cudf/exec/ExpressionEvaluator.cpp b/velox/experimental/cudf/exec/ExpressionEvaluator.cpp index 094721e678e4..700e532836c3 100644 --- a/velox/experimental/cudf/exec/ExpressionEvaluator.cpp +++ b/velox/experimental/cudf/exec/ExpressionEvaluator.cpp @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/experimental/cudf/CudfConfig.h" #include "velox/experimental/cudf/exec/ExpressionEvaluator.h" -#include "velox/experimental/cudf/exec/ToCudf.h" +#include "velox/core/Expressions.h" #include "velox/expression/ConstantExpr.h" #include "velox/expression/FieldReference.h" #include "velox/type/Type.h" @@ -26,7 +27,9 @@ #include #include +#include #include +#include #include #include #include @@ -40,6 +43,7 @@ namespace facebook::velox::cudf_velox { namespace { + template cudf::ast::literal makeScalarAndLiteral( const TypePtr& type, @@ -274,20 +278,45 @@ const std::unordered_set supportedOps = { "like", "cardinality", "split", - "lower"}; + "lower", + "round", + "hash_with_seed"}; namespace detail { -bool canBeEvaluated(const std::shared_ptr& expr) { - const auto name = - stripPrefix(expr->name(), CudfOptions::getInstance().prefix()); - if (supportedOps.count(name) || binaryOps.count(name) || - unaryOps.count(name)) { - return std::all_of( - expr->inputs().begin(), expr->inputs().end(), canBeEvaluated); +bool canBeEvaluated(const core::TypedExprPtr& expr) { + switch (expr->kind()) { + case core::ExprKind::kCast: { + const auto* cast = expr->asUnchecked(); + if (cast->isTryCast()) { + return false; + } + return canBeEvaluated(cast->inputs()[0]); + } + + case core::ExprKind::kCall: { + const auto* call = expr->asUnchecked(); + const auto name = stripPrefix( + call->name(), 
CudfConfig::getInstance().functionNamePrefix); + if (supportedOps.count(name) || binaryOps.count(name) || + unaryOps.count(name)) { + return std::all_of( + call->inputs().begin(), call->inputs().end(), canBeEvaluated); + } + return false; + } + + case core::ExprKind::kFieldAccess: + case core::ExprKind::kDereference: + case core::ExprKind::kConstant: + return true; + + case core::ExprKind::kInput: + case core::ExprKind::kConcat: + case core::ExprKind::kLambda: + default: + return false; } - return std::dynamic_pointer_cast(expr) != - nullptr; } } // namespace detail @@ -316,7 +345,6 @@ struct AstContext { const std::shared_ptr& node = nullptr); cudf::ast::expression const& multipleInputsToPairWise( const std::shared_ptr& expr); - static bool canBeEvaluated(const std::shared_ptr& expr); }; // Create tree from Expr @@ -410,7 +438,7 @@ cudf::ast::expression const& AstContext::multipleInputsToPairWise( using Operation = cudf::ast::operation; const auto name = - stripPrefix(expr->name(), CudfOptions::getInstance().prefix()); + stripPrefix(expr->name(), CudfConfig::getInstance().functionNamePrefix); auto len = expr->inputs().size(); // Create a simple chain of operations auto result = &pushExprToTree(expr->inputs()[0]); @@ -436,7 +464,7 @@ cudf::ast::expression const& AstContext::pushExprToTree( using velox::exec::FieldReference; const auto name = - stripPrefix(expr->name(), CudfOptions::getInstance().prefix()); + stripPrefix(expr->name(), CudfConfig::getInstance().functionNamePrefix); auto len = expr->inputs().size(); auto& type = expr->type(); @@ -622,10 +650,16 @@ cudf::ast::expression const& AstContext::pushExprToTree( addPrecomputeInstructionOnSide(0, 0, "cardinality", "", node); return tree.push(Operation{Op::CAST_TO_INT64, colRef}); + } else if (name == "round") { + auto node = CudfExpressionNode::create(expr); + return addPrecomputeInstructionOnSide(0, 0, "round", "", node); } else if (name == "split") { VELOX_CHECK_EQ(len, 3); auto node = 
CudfExpressionNode::create(expr); return addPrecomputeInstructionOnSide(0, 0, "split", "", node); + } else if (name == "hash_with_seed") { + auto node = CudfExpressionNode::create(expr); + return addPrecomputeInstructionOnSide(0, 0, "hash_with_seed", "", node); } else if (auto fieldExpr = std::dynamic_pointer_cast(expr)) { // Refer to the appropriate side const auto fieldName = @@ -707,6 +741,39 @@ class CardinalityFunction : public CudfFunction { } }; +class RoundFunction : public CudfFunction { + public: + explicit RoundFunction(const std::shared_ptr& expr) { + const auto argSize = expr->inputs().size(); + VELOX_CHECK(argSize >= 1 && argSize <= 2, "round expects 1 or 2 inputs"); + VELOX_CHECK_NULL( + std::dynamic_pointer_cast(expr->inputs()[0]), + "round expects first column is not literal"); + if (argSize == 2) { + auto scaleExpr = + std::dynamic_pointer_cast(expr->inputs()[1]); + VELOX_CHECK_NOT_NULL(scaleExpr, "round scale must be a constant"); + scale_ = scaleExpr->value()->as>()->valueAt(0); + } + } + + ColumnOrView eval( + std::vector& inputColumns, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const override { + return cudf::round_decimal( + asView(inputColumns[0]), + scale_, + cudf::rounding_method::HALF_UP, + stream, + mr); + ; + } + + private: + int32_t scale_ = 0; +}; + class SubstrFunction : public CudfFunction { public: SubstrFunction(const std::shared_ptr& expr) { @@ -774,6 +841,45 @@ class SubstrFunction : public CudfFunction { std::unique_ptr> stepScalar_; }; +class HashFunction : public CudfFunction { + public: + HashFunction(const std::shared_ptr& expr) { + using velox::exec::ConstantExpr; + VELOX_CHECK_GE(expr->inputs().size(), 2, "hash expects at least 2 inputs"); + auto seedExpr = std::dynamic_pointer_cast(expr->inputs()[0]); + VELOX_CHECK_NOT_NULL(seedExpr, "hash seed must be a constant"); + int32_t seedValue = + seedExpr->value()->as>()->valueAt(0); + VELOX_CHECK_GE(seedValue, 0); + seedValue_ = seedValue; + } + + 
ColumnOrView eval( + std::vector& inputColumns, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const override { + VELOX_CHECK(!inputColumns.empty()); + auto inputTableView = convertToTableView(inputColumns); + return cudf::hashing::murmurhash3_x86_32( + inputTableView, seedValue_, stream, mr); + } + + private: + static cudf::table_view convertToTableView( + std::vector& inputColumns) { + std::vector columns; + columns.reserve(inputColumns.size()); + + for (auto& col : inputColumns) { + columns.push_back(asView(col)); + } + + return cudf::table_view(columns); + } + + uint32_t seedValue_; +}; + std::unordered_map& getCudfFunctionRegistry() { static std::unordered_map registry; @@ -840,6 +946,24 @@ bool registerBuiltinFunctions(const std::string& prefix) { return std::make_shared(expr); }); + registerCudfFunction( + prefix + "hash_with_seed", + [](const std::string&, const std::shared_ptr& expr) { + return std::make_shared(expr); + }); + + registerCudfFunction( + "hash_with_seed", + [](const std::string&, const std::shared_ptr& expr) { + return std::make_shared(expr); + }); + + registerCudfFunction( + prefix + "round", + [](const std::string&, const std::shared_ptr& expr) { + return std::make_shared(expr); + }); + return true; } @@ -1026,7 +1150,7 @@ std::vector> ExpressionEvaluator::compute( } bool ExpressionEvaluator::canBeEvaluated( - const std::vector>& exprs) { + const std::vector& exprs) { return std::all_of(exprs.begin(), exprs.end(), detail::canBeEvaluated); } diff --git a/velox/experimental/cudf/exec/ExpressionEvaluator.h b/velox/experimental/cudf/exec/ExpressionEvaluator.h index e02a98e18177..e93a8ee83883 100644 --- a/velox/experimental/cudf/exec/ExpressionEvaluator.h +++ b/velox/experimental/cudf/exec/ExpressionEvaluator.h @@ -183,8 +183,7 @@ class ExpressionEvaluator { void close(); - static bool canBeEvaluated( - const std::vector>& exprs); + static bool canBeEvaluated(const std::vector& exprs); private: std::vector exprAst_; diff 
--git a/velox/experimental/cudf/exec/ToCudf.cpp b/velox/experimental/cudf/exec/ToCudf.cpp index 310282ffc68d..5d348c898462 100644 --- a/velox/experimental/cudf/exec/ToCudf.cpp +++ b/velox/experimental/cudf/exec/ToCudf.cpp @@ -14,6 +14,10 @@ * limitations under the License. */ +#include "velox/experimental/cudf/CudfConfig.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnector.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveDataSource.h" +#include "velox/experimental/cudf/exec/CudfAssignUniqueId.h" #include "velox/experimental/cudf/exec/CudfConversion.h" #include "velox/experimental/cudf/exec/CudfFilterProject.h" #include "velox/experimental/cudf/exec/CudfHashAggregation.h" @@ -25,6 +29,10 @@ #include "velox/experimental/cudf/exec/ToCudf.h" #include "velox/experimental/cudf/exec/Utilities.h" +#include "folly/Conv.h" +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/TableHandle.h" +#include "velox/exec/AssignUniqueId.h" #include "velox/exec/Driver.h" #include "velox/exec/FilterProject.h" #include "velox/exec/HashAggregation.h" @@ -41,10 +49,7 @@ #include -DEFINE_bool(velox_cudf_enabled, true, "Enable cuDF-Velox acceleration"); -DEFINE_string(velox_cudf_memory_resource, "async", "Memory resource for cuDF"); -DEFINE_bool(velox_cudf_debug, false, "Enable debug printing"); -DEFINE_bool(velox_cudf_table_scan, true, "Enable cuDF table scan"); +static const std::string kCudfAdapterName = "cuDF"; namespace facebook::velox::cudf_velox { @@ -57,10 +62,10 @@ bool isAnyOf(const Base* p) { } // namespace -bool CompileState::compile() { +bool CompileState::compile(bool force_replace) { auto operators = driver_.operators(); - if (FLAGS_velox_cudf_debug) { + if (CudfConfig::getInstance().debugEnabled) { std::cout << "Operators before adapting for cuDF: count [" << operators.size() << "]" << std::endl; for (auto& op : operators) { @@ -69,10 +74,6 @@ bool CompileState::compile() { } } - // Make sure operator states are 
initialized. We will need to inspect some of - // them during the transformation. - driver_.initializeOperators(); - bool replacementsMade = false; auto ctx = driver_.driverCtx(); @@ -90,18 +91,42 @@ bool CompileState::compile() { return driverFactory_.consumerNode; }; - const bool isParquetConnectorRegistered = - facebook::velox::connector::getAllConnectors().count("test-parquet") > 0; - auto isTableScanSupported = - [isParquetConnectorRegistered](const exec::Operator* op) { - return isAnyOf(op) && isParquetConnectorRegistered && - cudfTableScanEnabled(); - }; + auto isTableScanSupported = [getPlanNode](const exec::Operator* op) { + if (!isAnyOf(op)) { + return false; + } + auto tableScanNode = std::dynamic_pointer_cast( + getPlanNode(op->planNodeId())); + VELOX_CHECK(tableScanNode != nullptr); + auto const& connector = velox::connector::getConnector( + tableScanNode->tableHandle()->connectorId()); + auto cudfHiveConnector = std::dynamic_pointer_cast< + facebook::velox::cudf_velox::connector::hive::CudfHiveConnector>( + connector); + if (!cudfHiveConnector) { + return false; + } + // TODO (dm): we need to ask CudfHiveConnector whether this table handle is + // supported by it. It may choose to produce a HiveDatasource. 
+ return true; + }; - auto isFilterProjectSupported = [](const exec::Operator* op) { + auto isFilterProjectSupported = [getPlanNode](const exec::Operator* op) { if (auto filterProjectOp = dynamic_cast(op)) { - auto info = filterProjectOp->exprsAndProjection(); - return ExpressionEvaluator::canBeEvaluated(info.exprs->exprs()); + auto projectPlanNode = std::dynamic_pointer_cast( + getPlanNode(filterProjectOp->planNodeId())); + auto filterNode = filterProjectOp->filterNode(); + bool canBeEvaluated = true; + if (projectPlanNode && + !ExpressionEvaluator::canBeEvaluated( + projectPlanNode->projections())) { + canBeEvaluated = false; + } + if (canBeEvaluated && filterNode && + !ExpressionEvaluator::canBeEvaluated({filterNode->filter()})) { + canBeEvaluated = false; + } + return canBeEvaluated; } return false; }; @@ -134,7 +159,8 @@ bool CompileState::compile() { exec::HashAggregation, exec::Limit, exec::LocalPartition, - exec::LocalExchange>(op) || + exec::LocalExchange, + exec::AssignUniqueId>(op) || isFilterProjectSupported(op) || isJoinSupported(op) || isTableScanSupported(op); }; @@ -151,7 +177,8 @@ bool CompileState::compile() { exec::OrderBy, exec::HashAggregation, exec::Limit, - exec::LocalPartition>(op) || + exec::LocalPartition, + exec::AssignUniqueId>(op) || isFilterProjectSupported(op) || isJoinSupported(op); }; auto producesGpuOutput = [isFilterProjectSupported, @@ -161,7 +188,8 @@ bool CompileState::compile() { exec::OrderBy, exec::HashAggregation, exec::Limit, - exec::LocalExchange>(op) || + exec::LocalExchange, + exec::AssignUniqueId>(op) || isFilterProjectSupported(op) || (isAnyOf(op) && isJoinSupported(op)) || (isTableScanSupported(op)); @@ -189,7 +217,18 @@ bool CompileState::compile() { auto planNode = getPlanNode(oper->planNodeId()); replaceOp.push_back(std::make_unique( id, planNode->outputType(), ctx, planNode->id() + "-from-velox")); - replaceOp.back()->initialize(); + } + if (not replaceOp.empty()) { + // from-velox only, because need to inserted 
before current operator. + operatorsOffset += replaceOp.size(); + [[maybe_unused]] auto replaced = driverFactory_.replaceOperators( + driver_, + replacingOperatorIndex, + replacingOperatorIndex, + std::move(replaceOp)); + replacingOperatorIndex = operatorIndex + operatorsOffset; + replaceOp.clear(); + replacementsMade = true; } // This is used to denote if the current operator is kept or replaced. @@ -208,7 +247,6 @@ bool CompileState::compile() { // From-Velox (optional) replaceOp.push_back( std::make_unique(id, ctx, planNode)); - replaceOp.back()->initialize(); } else if (auto joinProbeOp = dynamic_cast(oper)) { auto planNode = std::dynamic_pointer_cast( getPlanNode(joinProbeOp->planNodeId())); @@ -216,7 +254,6 @@ bool CompileState::compile() { // From-Velox (optional) replaceOp.push_back( std::make_unique(id, ctx, planNode)); - replaceOp.back()->initialize(); // To-Velox (optional) } } else if (auto orderByOp = dynamic_cast(oper)) { @@ -225,45 +262,56 @@ bool CompileState::compile() { getPlanNode(orderByOp->planNodeId())); VELOX_CHECK(planNode != nullptr); replaceOp.push_back(std::make_unique(id, ctx, planNode)); - replaceOp.back()->initialize(); } else if (auto hashAggOp = dynamic_cast(oper)) { auto planNode = std::dynamic_pointer_cast( getPlanNode(hashAggOp->planNodeId())); VELOX_CHECK(planNode != nullptr); replaceOp.push_back( std::make_unique(id, ctx, planNode)); - replaceOp.back()->initialize(); } else if (isFilterProjectSupported(oper)) { auto filterProjectOp = dynamic_cast(oper); - auto info = filterProjectOp->exprsAndProjection(); - auto& idProjections = filterProjectOp->identityProjections(); auto projectPlanNode = std::dynamic_pointer_cast( getPlanNode(filterProjectOp->planNodeId())); - auto filterPlanNode = std::dynamic_pointer_cast( - getPlanNode(filterProjectOp->planNodeId())); - // If filter only, filter node only exists. - // If project only, or filter and project, project node only exists. 
+ // When filter and project both exist, the FilterProject planNodeId id is + // project node id, so we need FilterProject to report the FilterNode. + auto filterPlanNode = filterProjectOp->filterNode(); VELOX_CHECK(projectPlanNode != nullptr or filterPlanNode != nullptr); replaceOp.push_back(std::make_unique( - id, ctx, info, idProjections, filterPlanNode, projectPlanNode)); - replaceOp.back()->initialize(); + id, ctx, filterPlanNode, projectPlanNode)); } else if (auto limitOp = dynamic_cast(oper)) { auto planNode = std::dynamic_pointer_cast( getPlanNode(limitOp->planNodeId())); VELOX_CHECK(planNode != nullptr); replaceOp.push_back(std::make_unique(id, ctx, planNode)); - replaceOp.back()->initialize(); } else if ( auto localPartitionOp = dynamic_cast(oper)) { auto planNode = std::dynamic_pointer_cast( getPlanNode(localPartitionOp->planNodeId())); VELOX_CHECK(planNode != nullptr); - replaceOp.push_back( - std::make_unique(id, ctx, planNode)); - replaceOp.back()->initialize(); + if (CudfLocalPartition::shouldReplace(planNode)) { + replaceOp.push_back( + std::make_unique(id, ctx, planNode)); + replaceOp.back()->initialize(); + } else { + // Round Robin batch-wise Partitioning is supported by CPU operator with + // GPU Vector. 
+ keepOperator = 1; + } } else if ( auto localExchangeOp = dynamic_cast(oper)) { keepOperator = 1; + } else if ( + auto assignUniqueIdOp = dynamic_cast(oper)) { + auto planNode = std::dynamic_pointer_cast( + getPlanNode(assignUniqueIdOp->planNodeId())); + VELOX_CHECK(planNode != nullptr); + replaceOp.push_back(std::make_unique( + id, + ctx, + planNode, + planNode->taskUniqueId(), + planNode->uniqueIdCounter())); + replaceOp.back()->initialize(); } if (producesGpuOutput(oper) and @@ -271,10 +319,39 @@ bool CompileState::compile() { auto planNode = getPlanNode(oper->planNodeId()); replaceOp.push_back(std::make_unique( id, planNode->outputType(), ctx, planNode->id() + "-to-velox")); - replaceOp.back()->initialize(); + } + + if (force_replace) { + if (CudfConfig::getInstance().debugEnabled) { + std::printf( + "Operator: ID %d: %s, keepOperator = %d, replaceOp.size() = %ld\n", + oper->operatorId(), + oper->toString().c_str(), + keepOperator, + replaceOp.size()); + } + auto shouldSupportGpuOperator = + [isFilterProjectSupported, + isTableScanSupported](const exec::Operator* op) { + return isAnyOf< + exec::OrderBy, + exec::TableScan, + exec::HashAggregation, + exec::Limit, + exec::LocalPartition, + exec::LocalExchange, + exec::HashBuild, + exec::HashProbe>(op) || + isFilterProjectSupported(op); + }; + VELOX_CHECK( + !(keepOperator == 0 && shouldSupportGpuOperator(oper) && + replaceOp.empty()), + "Replacement with cuDF operator failed"); } if (not replaceOp.empty()) { + // ReplaceOp, to-velox. 
operatorsOffset += replaceOp.size() - 1 + keepOperator; [[maybe_unused]] auto replaced = driverFactory_.replaceOperators( driver_, @@ -285,7 +362,7 @@ bool CompileState::compile() { } } - if (FLAGS_velox_cudf_debug) { + if (CudfConfig::getInstance().debugEnabled) { operators = driver_.operators(); std::cout << "Operators after adapting for cuDF: count [" << operators.size() << "]" << std::endl; @@ -298,48 +375,57 @@ bool CompileState::compile() { return replacementsMade; } +std::shared_ptr mr_; + struct CudfDriverAdapter { - std::shared_ptr mr_; + bool force_replace_; - CudfDriverAdapter(std::shared_ptr mr) - : mr_(mr) {} + CudfDriverAdapter(bool force_replace) : force_replace_{force_replace} {} // Call operator needed by DriverAdapter bool operator()(const exec::DriverFactory& factory, exec::Driver& driver) { + if (!driver.driverCtx()->queryConfig().get( + CudfConfig::kCudfEnabled, CudfConfig::getInstance().enabled)) { + return false; + } auto state = CompileState(factory, driver); - auto res = state.compile(); + auto res = state.compile(force_replace_); return res; } }; static bool isCudfRegistered = false; -void registerCudf(const CudfOptions& options) { +bool cudfIsRegistered() { + return isCudfRegistered; +} + +void registerCudf() { if (cudfIsRegistered()) { return; } - if (!options.cudfEnabled) { - return; - } - registerBuiltinFunctions(options.prefix()); + registerBuiltinFunctions(CudfConfig::getInstance().functionNamePrefix); CUDF_FUNC_RANGE(); cudaFree(nullptr); // Initialize CUDA context at startup - const std::string mrMode = options.cudfMemoryResource; - auto mr = cudf_velox::createMemoryResource(mrMode, options.memoryPercent); + const std::string mrMode = CudfConfig::getInstance().memoryResource; + auto mr = cudf_velox::createMemoryResource( + mrMode, CudfConfig::getInstance().memoryPercent); cudf::set_current_device_resource(mr.get()); + mr_ = mr; exec::Operator::registerOperator( std::make_unique()); - CudfDriverAdapter cda{mr}; + CudfDriverAdapter 
cda{CudfConfig::getInstance().forceReplace}; exec::DriverAdapter cudfAdapter{kCudfAdapterName, {}, cda}; exec::DriverFactory::registerAdapter(cudfAdapter); isCudfRegistered = true; } void unregisterCudf() { + mr_ = nullptr; exec::DriverFactory::adapters.erase( std::remove_if( exec::DriverFactory::adapters.begin(), @@ -352,16 +438,31 @@ void unregisterCudf() { isCudfRegistered = false; } -bool cudfIsRegistered() { - return isCudfRegistered; -} - -bool cudfDebugEnabled() { - return FLAGS_velox_cudf_debug; +CudfConfig& CudfConfig::getInstance() { + static CudfConfig instance; + return instance; } -bool cudfTableScanEnabled() { - return CudfOptions::getInstance().cudfTableScan; +void CudfConfig::initialize( + std::unordered_map&& config) { + if (config.find(kCudfEnabled) != config.end()) { + enabled = folly::to(config[kCudfEnabled]); + } + if (config.find(kCudfDebugEnabled) != config.end()) { + debugEnabled = folly::to(config[kCudfDebugEnabled]); + } + if (config.find(kCudfMemoryResource) != config.end()) { + memoryResource = config[kCudfMemoryResource]; + } + if (config.find(kCudfMemoryPercent) != config.end()) { + memoryPercent = folly::to(config[kCudfMemoryPercent]); + } + if (config.find(kCudfFunctionNamePrefix) != config.end()) { + functionNamePrefix = config[kCudfFunctionNamePrefix]; + } + if (config.find(kCudfForceReplace) != config.end()) { + forceReplace = folly::to(config[kCudfForceReplace]); + } } } // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/exec/ToCudf.h b/velox/experimental/cudf/exec/ToCudf.h index 04ec3eeaa576..cf0f89855d4f 100644 --- a/velox/experimental/cudf/exec/ToCudf.h +++ b/velox/experimental/cudf/exec/ToCudf.h @@ -19,17 +19,10 @@ #include "velox/exec/Driver.h" #include "velox/exec/Operator.h" -#include - -DECLARE_bool(velox_cudf_enabled); -DECLARE_string(velox_cudf_memory_resource); -DECLARE_bool(velox_cudf_debug); -DECLARE_bool(velox_cudf_table_scan); +#include namespace facebook::velox::cudf_velox { -static 
const std::string kCudfAdapterName = "cuDF"; - class CompileState { public: CompileState(const exec::DriverFactory& driverFactory, exec::Driver& driver) @@ -41,61 +34,19 @@ class CompileState { // Replaces sequences of Operators in the Driver given at construction with // cuDF equivalents. Returns true if the Driver was changed. - bool compile(); + bool compile(bool force_replace); const exec::DriverFactory& driverFactory_; exec::Driver& driver_; }; -class CudfOptions { - public: - static CudfOptions& getInstance() { - static CudfOptions instance; - return instance; - } - - void setPrefix(const std::string& prefix) { - prefix_ = prefix; - } - - const std::string& prefix() const { - return prefix_; - } - - const bool cudfEnabled; - const std::string cudfMemoryResource; - const bool cudfTableScan; - // The initial percent of GPU memory to allocate for memory resource for one - // thread. - int memoryPercent; - - private: - CudfOptions() - : cudfEnabled(FLAGS_velox_cudf_enabled), - cudfMemoryResource(FLAGS_velox_cudf_memory_resource), - cudfTableScan(FLAGS_velox_cudf_table_scan), - memoryPercent(50), - prefix_("") {} - CudfOptions(const CudfOptions&) = delete; - CudfOptions& operator=(const CudfOptions&) = delete; - std::string prefix_; -}; +extern std::shared_ptr mr_; /// Registers adapter to add cuDF operators to Drivers. -void registerCudf(const CudfOptions& options = CudfOptions::getInstance()); +void registerCudf(); void unregisterCudf(); /// Returns true if cuDF is registered. bool cudfIsRegistered(); -/** - * @brief Returns true if the velox_cudf_debug flag is set to true. - */ -bool cudfDebugEnabled(); - -/** - * @brief Returns true if the velox_cudf_table_scan flag is set to true. 
- */ -bool cudfTableScanEnabled(); - } // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/exec/Utilities.cpp b/velox/experimental/cudf/exec/Utilities.cpp index fe8678e3ec31..3dc85b7cb113 100644 --- a/velox/experimental/cudf/exec/Utilities.cpp +++ b/velox/experimental/cudf/exec/Utilities.cpp @@ -15,7 +15,9 @@ */ #include "velox/experimental/cudf/exec/Utilities.h" +#include "velox/experimental/cudf/exec/VeloxCudfInterop.h" +#include #include #include #include @@ -48,9 +50,8 @@ namespace { makeCudaMr(), rmm::percent_of_free_device_memory(percent)); } -[[nodiscard]] auto makeAsyncMr(int percent) { - return std::make_shared( - rmm::percent_of_free_device_memory(percent)); +[[nodiscard]] auto makeAsyncMr() { + return std::make_shared(); } [[nodiscard]] auto makeManagedMr() { @@ -76,7 +77,7 @@ std::shared_ptr createMemoryResource( if (mode == "pool") return makePoolMr(percent); if (mode == "async") - return makeAsyncMr(percent); + return makeAsyncMr(); if (mode == "arena") return makeArenaMr(percent); if (mode == "managed") @@ -112,11 +113,37 @@ std::unique_ptr concatenateTables( tableViews, stream, cudf::get_current_device_resource_ref()); } +std::unique_ptr makeEmptyTable(TypePtr const& inputType) { + std::vector> emptyColumns; + for (size_t i = 0; i < inputType->size(); ++i) { + if (auto const& childType = inputType->childAt(i); + childType->kind() == TypeKind::ROW) { + auto tbl = makeEmptyTable(childType); + auto structColumn = std::make_unique( + cudf::data_type(cudf::type_id::STRUCT), + 0, + rmm::device_buffer(), + rmm::device_buffer(), + 0, + tbl->release()); + emptyColumns.push_back(std::move(structColumn)); + } else { + auto emptyColumn = cudf::make_empty_column( + cudf_velox::veloxToCudfTypeId(inputType->childAt(i))); + emptyColumns.push_back(std::move(emptyColumn)); + } + } + return std::make_unique(std::move(emptyColumns)); +} + std::unique_ptr getConcatenatedTable( std::vector& tables, + const TypePtr& tableType, 
rmm::cuda_stream_view stream) { // Check for empty vector - VELOX_CHECK_GT(tables.size(), 0); + if (tables.size() == 0) { + return makeEmptyTable(tableType); + } auto inputStreams = std::vector(); auto tableViews = std::vector(); diff --git a/velox/experimental/cudf/exec/Utilities.h b/velox/experimental/cudf/exec/Utilities.h index d1cd239af459..5e242af37e64 100644 --- a/velox/experimental/cudf/exec/Utilities.h +++ b/velox/experimental/cudf/exec/Utilities.h @@ -53,6 +53,7 @@ createMemoryResource(std::string_view mode, int percent); // stream. Inputs are not safe to use after calling this function. [[nodiscard]] std::unique_ptr getConcatenatedTable( std::vector& tables, + const TypePtr& tableType, rmm::cuda_stream_view stream); } // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/tests/AggregationTest.cpp b/velox/experimental/cudf/tests/AggregationTest.cpp index 7ad8548de21c..4f5c621d9383 100644 --- a/velox/experimental/cudf/tests/AggregationTest.cpp +++ b/velox/experimental/cudf/tests/AggregationTest.cpp @@ -667,15 +667,15 @@ TEST_F(EmptyInputAggregationTest, groupedPartialFinalAggregation) { plan_ = PlanBuilder() .values({data_}) .filter(filter_) - .partialAggregation({"c2"}, {"sum(c0)", "count(c1)", "max(c1)"}) + .partialAggregation( + {"c2"}, {"sum(c0)", "count(c1)", "max(c1)", "avg(c1)"}) .finalAggregation() .planNode(); - // TODO (dm): "avg(c1)" // should return empty result for partial-final aggregation assertQuery( plan_, - "SELECT c2, sum(c0), count(c1), max(c1) FROM tmp WHERE c0 > 10 GROUP BY c2"); + "SELECT c2, sum(c0), count(c1), max(c1), avg(c1) FROM tmp WHERE c0 > 10 GROUP BY c2"); } TEST_F(EmptyInputAggregationTest, globalPartialFinalAggregation) { @@ -684,14 +684,15 @@ TEST_F(EmptyInputAggregationTest, globalPartialFinalAggregation) { plan_ = PlanBuilder() .values({data_}) .filter(filter_) - .partialAggregation({}, {"sum(c0)", "count(c1)", "max(c1)"}) + .partialAggregation( + {}, {"sum(c0)", "count(c1)", "max(c1)", 
"avg(c1)"}) .finalAggregation() .planNode(); - // TODO (dm): "avg(c1)" // global partial-final aggregation should return 1 row with null/zero values assertQuery( - plan_, "SELECT sum(c0), count(c1), max(c1) FROM tmp WHERE c0 > 10"); + plan_, + "SELECT sum(c0), count(c1), max(c1), avg(c1) FROM tmp WHERE c0 > 10"); } } // namespace facebook::velox::exec::test diff --git a/velox/experimental/cudf/tests/AssignUniqueIdTest.cpp b/velox/experimental/cudf/tests/AssignUniqueIdTest.cpp new file mode 100644 index 000000000000..6d4e84ad1da4 --- /dev/null +++ b/velox/experimental/cudf/tests/AssignUniqueIdTest.cpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "velox/experimental/cudf/exec/CudfConversion.h" +#include "velox/experimental/cudf/exec/ToCudf.h" + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/exec/PlanNodeStats.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/QueryAssertions.h" + +namespace facebook::velox::exec { + +using namespace facebook::velox::test; +using namespace facebook::velox::exec::test; + +namespace { + +class AssignUniqueIdTest : public HiveConnectorTestBase { + protected: + void SetUp() override { + HiveConnectorTestBase::SetUp(); + cudf_velox::registerCudf(); + } + + void TearDown() override { + cudf_velox::unregisterCudf(); + HiveConnectorTestBase::TearDown(); + } + + void verifyUniqueId( + const std::shared_ptr& plan, + const std::vector& input) { + CursorParameters params; + params.planNode = plan; + params.queryConfigs.insert( + {cudf_velox::CudfFromVelox::kGpuBatchSizeRows, "1"}); + auto result = readCursor(params); + ASSERT_EQ(result.second[0]->childrenSize(), input[0]->childrenSize() + 1); + verifyUniqueId(input, result.second); + + auto task = result.first->task(); + // Verify number of memory allocations. As no memory is allocated for the + // unique ID vector generation in GPU, the number of memory allocations + // should be 0. cudf can generate new columns, and cannot reuse the memory. 
+ auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(0, stats.at(uniqueNodeId_).numMemoryAllocations); + // check if CudfAssignUniqueId operator is executed with stats + ASSERT_EQ( + 1, stats.at(uniqueNodeId_).operatorStats.count("CudfAssignUniqueId")); + } + + void verifyUniqueId( + const std::vector& input, + const std::vector& vectors) { + auto numColumns = vectors[0]->childrenSize(); + ASSERT_EQ(numColumns, input[0]->childrenSize() + 1); + + std::set ids; + for (int i = 0; i < numColumns; i++) { + for (auto batch = 0; batch < vectors.size(); ++batch) { + auto column = vectors[batch]->childAt(i); + if (i < numColumns - 1) { + assertEqualVectors(input[batch]->childAt(i), column); + } else { + auto idValues = column->asFlatVector()->rawValues(); + std::copy( + idValues, + idValues + column->size(), + std::inserter(ids, ids.end())); + } + } + } + + vector_size_t totalInputSize = 0; + for (const auto& vector : input) { + totalInputSize += vector->size(); + } + + ASSERT_EQ(totalInputSize, ids.size()); + } + + core::PlanNodeId uniqueNodeId_; +}; + +TEST_F(AssignUniqueIdTest, multiBatch) { + vector_size_t batchSize = 1000; + std::vector input; + input.reserve(3); + for (int i = 0; i < 3; ++i) { + input.push_back( + makeRowVector({makeFlatVector(batchSize, folly::identity)})); + } + + auto plan = PlanBuilder() + .values(input) + .assignUniqueId() + .capturePlanNodeId(uniqueNodeId_) + .planNode(); + + verifyUniqueId(plan, input); +} + +TEST_F(AssignUniqueIdTest, exceedRequestLimit) { + vector_size_t requestLimit = 1 << 20L; + auto input = { + makeRowVector( + {makeFlatVector(requestLimit - 10, folly::identity)}), + makeRowVector({makeFlatVector(100, folly::identity)}), + makeRowVector({makeFlatVector(100, folly::identity)}), + }; + + auto plan = PlanBuilder() + .values(input) + .assignUniqueId() + .capturePlanNodeId(uniqueNodeId_) + .planNode(); + + verifyUniqueId(plan, input); +} + +TEST_F(AssignUniqueIdTest, multiThread) { + for (int i = 0; i < 3; i++) { + 
vector_size_t batchSize = 1000; + auto input = { + makeRowVector({makeFlatVector(batchSize, folly::identity)})}; + auto plan = PlanBuilder() + .values(input, true) + .assignUniqueId() + .capturePlanNodeId(uniqueNodeId_) + .planNode(); + + std::shared_ptr task; + auto result = AssertQueryBuilder(plan) + .config(cudf_velox::CudfFromVelox::kGpuBatchSizeRows, "1") + .maxDrivers(8) + .copyResults(pool(), task); + ASSERT_EQ(batchSize * 8, result->size()); + + std::set ids; + auto idValues = + result->children().back()->asFlatVector()->rawValues(); + std::copy( + idValues, idValues + result->size(), std::inserter(ids, ids.end())); + + ASSERT_EQ(batchSize * 8, ids.size()); + + // Verify number of memory allocations. As no memory is allocated for the + // unique ID vector generation in GPU, the number of memory allocations + // should be 0. cudf can generate new columns, and cannot reuse the memory. + auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(0, stats.at(uniqueNodeId_).numMemoryAllocations); + } +} + +TEST_F(AssignUniqueIdTest, maxRowIdLimit) { + auto input = {makeRowVector({makeFlatVector({1, 2, 3})})}; + + auto plan = PlanBuilder().values(input).assignUniqueId().planNode(); + + // Increase the counter to kMaxRowId. + std::dynamic_pointer_cast(plan) + ->uniqueIdCounter() + ->fetch_add(1L << 40); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan).copyResults(pool()), + "Ran out of unique IDs at 1099511627776"); +} + +TEST_F(AssignUniqueIdTest, taskUniqueIdLimit) { + auto input = {makeRowVector({makeFlatVector({1, 2, 3})})}; + + auto plan = + PlanBuilder().values(input).assignUniqueId("unique", 1L << 24).planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan).copyResults(pool()), + "(16777216 vs. 16777216) Unique 24-bit ID specified for AssignUniqueId exceeds the limit"); +} + +// TODO: Add test for barrier execution, other operators does not support +// draining yet. 
+} // namespace +} // namespace facebook::velox::exec diff --git a/velox/experimental/cudf/tests/CMakeLists.txt b/velox/experimental/cudf/tests/CMakeLists.txt index e5cbd6950830..093fab7aeee9 100644 --- a/velox/experimental/cudf/tests/CMakeLists.txt +++ b/velox/experimental/cudf/tests/CMakeLists.txt @@ -16,11 +16,14 @@ add_executable(velox_cudf_hash_join_test HashJoinTest.cpp Main.cpp) add_executable(velox_cudf_order_by_test Main.cpp OrderByTest.cpp) add_executable(velox_cudf_aggregation_test Main.cpp AggregationTest.cpp) add_executable(velox_cudf_table_scan_test Main.cpp TableScanTest.cpp) -add_executable(velox_cudf_table_write_test Main.cpp TableWriteTest.cpp) +# Disabling writer tests until we re-add writing ability to CudfHiveConnector +# add_executable(velox_cudf_table_write_test Main.cpp TableWriteTest.cpp) add_executable(velox_cudf_local_partition_test Main.cpp LocalPartitionTest.cpp) add_executable(velox_cudf_filter_project_test Main.cpp FilterProjectTest.cpp) add_executable(velox_cudf_subfield_filter_ast_test Main.cpp SubfieldFilterAstTest.cpp) add_executable(velox_cudf_limit_test Main.cpp LimitTest.cpp) +add_executable(velox_cudf_config_test Main.cpp ConfigTest.cpp) +add_executable(velox_cudf_assign_unique_id_test Main.cpp AssignUniqueIdTest.cpp) add_test( NAME velox_cudf_hash_join_test @@ -52,11 +55,11 @@ add_test( WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} ) -add_test( - NAME velox_cudf_table_write_test - COMMAND velox_cudf_table_write_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} -) +# add_test( +# NAME velox_cudf_table_write_test +# COMMAND velox_cudf_table_write_test +# WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +# ) add_test( NAME velox_cudf_filter_project_test @@ -76,15 +79,29 @@ add_test( WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} ) +add_test( + NAME velox_cudf_assign_unique_id_test + COMMAND velox_cudf_assign_unique_id_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +add_test( + NAME velox_cudf_config_test + COMMAND 
velox_cudf_config_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + set_tests_properties(velox_cudf_hash_join_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) set_tests_properties(velox_cudf_order_by_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) set_tests_properties(velox_cudf_aggregation_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) set_tests_properties(velox_cudf_local_partition_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) set_tests_properties(velox_cudf_table_scan_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) -set_tests_properties(velox_cudf_table_write_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) +# set_tests_properties(velox_cudf_table_write_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) set_tests_properties(velox_cudf_filter_project_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) set_tests_properties(velox_cudf_subfield_filter_ast_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) set_tests_properties(velox_cudf_limit_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) +set_tests_properties(velox_cudf_config_test PROPERTIES LABELS cuda_driver TIMEOUT 300) +set_tests_properties(velox_cudf_assign_unique_id_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) target_link_libraries( velox_cudf_hash_join_test @@ -136,7 +153,7 @@ target_link_libraries( target_link_libraries( velox_cudf_table_scan_test velox_cudf_exec_test_lib - velox_cudf_parquet_connector + velox_cudf_hive_connector velox_exec velox_exec_test_lib velox_test_util @@ -145,20 +162,30 @@ target_link_libraries( fmt::fmt ) +# target_link_libraries( +# velox_cudf_table_write_test +# velox_cudf_exec_test_lib +# velox_cudf_hive_connector +# velox_exec +# velox_exec_test_lib +# velox_test_util +# gtest +# gtest_main +# fmt::fmt +# ) + target_link_libraries( - velox_cudf_table_write_test - velox_cudf_exec_test_lib - velox_cudf_parquet_connector + velox_cudf_filter_project_test + velox_cudf_exec velox_exec velox_exec_test_lib velox_test_util gtest gtest_main - fmt::fmt ) 
target_link_libraries( - velox_cudf_filter_project_test + velox_cudf_subfield_filter_ast_test velox_cudf_exec velox_exec velox_exec_test_lib @@ -168,7 +195,7 @@ target_link_libraries( ) target_link_libraries( - velox_cudf_subfield_filter_ast_test + velox_cudf_limit_test velox_cudf_exec velox_exec velox_exec_test_lib @@ -177,8 +204,9 @@ target_link_libraries( gtest_main ) +target_link_libraries(velox_cudf_config_test velox_cudf_exec gtest gtest_main) target_link_libraries( - velox_cudf_limit_test + velox_cudf_assign_unique_id_test velox_cudf_exec velox_exec velox_exec_test_lib diff --git a/velox/experimental/cudf/tests/ConfigTest.cpp b/velox/experimental/cudf/tests/ConfigTest.cpp new file mode 100644 index 000000000000..c4279acf2571 --- /dev/null +++ b/velox/experimental/cudf/tests/ConfigTest.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/experimental/cudf/CudfConfig.h" + +#include + +namespace facebook::velox::cudf_velox::test { + +TEST(ConfigTest, CudfConfig) { + std::unordered_map options = { + {CudfConfig::kCudfEnabled, "false"}, + {CudfConfig::kCudfDebugEnabled, "true"}, + {CudfConfig::kCudfMemoryResource, "arena"}, + {CudfConfig::kCudfMemoryPercent, "25"}, + {CudfConfig::kCudfFunctionNamePrefix, "presto"}, + {CudfConfig::kCudfForceReplace, "true"}}; + + CudfConfig config; + config.initialize(std::move(options)); + ASSERT_EQ(config.enabled, false); + ASSERT_EQ(config.debugEnabled, true); + ASSERT_EQ(config.memoryResource, "arena"); + ASSERT_EQ(config.memoryPercent, 25); + ASSERT_EQ(config.functionNamePrefix, "presto"); + ASSERT_EQ(config.forceReplace, true); +} +} // namespace facebook::velox::cudf_velox::test diff --git a/velox/experimental/cudf/tests/FilterProjectTest.cpp b/velox/experimental/cudf/tests/FilterProjectTest.cpp index e591b933a895..581a404d2e85 100644 --- a/velox/experimental/cudf/tests/FilterProjectTest.cpp +++ b/velox/experimental/cudf/tests/FilterProjectTest.cpp @@ -515,10 +515,9 @@ TEST_F(CudfFilterProjectTest, yearFunction) { testYearFunction(vectors); } -TEST_F(CudfFilterProjectTest, DISABLED_caseWhenOperation) { +TEST_F(CudfFilterProjectTest, caseWhenOperation) { vector_size_t batchSize = 1000; auto vectors = makeVectors(rowType_, 2, batchSize); - // failing because switch copies nulls too. 
createDuckDbTable(vectors); testCaseWhenOperation(vectors); @@ -636,6 +635,32 @@ TEST_F(CudfFilterProjectTest, mixedInOperation) { testMixedInOperation(vectors); } +TEST_F(CudfFilterProjectTest, round) { + auto data = makeRowVector({makeFlatVector({4123, 456789098})}); + parse::ParseOptions options; + options.parseIntegerAsBigint = false; + auto plan = PlanBuilder() + .setParseOptions(options) + .values({data}) + .project({"round(c0, 2) as c1"}) + .planNode(); + AssertQueryBuilder(plan).assertResults(data); + plan = PlanBuilder() + .setParseOptions(options) + .values({data}) + .project({"round(c0) as c1"}) + .planNode(); + AssertQueryBuilder(plan).assertResults(data); + + plan = PlanBuilder() + .setParseOptions(options) + .values({data}) + .project({"round(c0, -3) as c1"}) + .planNode(); + auto expected = makeRowVector({makeFlatVector({4000, 456789000})}); + AssertQueryBuilder(plan).assertResults(expected); +} + TEST_F(CudfFilterProjectTest, simpleFilter) { vector_size_t batchSize = 1000; auto vectors = makeVectors(rowType_, 2, batchSize); diff --git a/velox/experimental/cudf/tests/HashJoinTest.cpp b/velox/experimental/cudf/tests/HashJoinTest.cpp index 15bf0b15a8fd..9de0113ef683 100644 --- a/velox/experimental/cudf/tests/HashJoinTest.cpp +++ b/velox/experimental/cudf/tests/HashJoinTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include "velox/experimental/cudf/CudfConfig.h" #include "velox/experimental/cudf/exec/ToCudf.h" #include "folly/experimental/EventCount.h" @@ -54,6 +55,7 @@ class HashJoinTest : public HashJoinTestBase { void SetUp() override { HashJoinTestBase::SetUp(); + cudf_velox::CudfConfig::getInstance().forceReplace = true; cudf_velox::registerCudf(); } @@ -421,6 +423,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithLargeOutput) { }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) .numDrivers(numDrivers_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) @@ -727,6 +730,7 @@ TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilterWithExtraFilter) { TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) .numDrivers(numDrivers_) .probeType(probeType_) .probeVectors(133, 3) @@ -826,6 +830,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithAllMatches) { }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) .numDrivers(numDrivers_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) @@ -861,6 +866,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { auto testProbeVectors = probeVectors; auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) .numDrivers(numDrivers_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) @@ -871,10 +877,6 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { .joinOutputLayout({"u0", "u1"}) .referenceQuery( "SELECT u.* FROM u WHERE EXISTS (SELECT t0 FROM t WHERE u0 = t0 AND t1 > -1)") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - ASSERT_EQ( - getOutputPositions(task, "HashProbe"), 200 * 5 * numDrivers_); - }) .run(); } @@ -883,6 +885,7 @@ TEST_P(MultiThreadedHashJoinTest, 
rightSemiJoinFilterWithExtraFilter) { auto testProbeVectors = probeVectors; auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) .numDrivers(numDrivers_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) @@ -893,9 +896,6 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { .joinOutputLayout({"u0", "u1"}) .referenceQuery( "SELECT u.* FROM u WHERE EXISTS (SELECT t0 FROM t WHERE u0 = t0 AND t1 > 100000)") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - ASSERT_EQ(getOutputPositions(task, "HashProbe"), 0); - }) .run(); } @@ -904,6 +904,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { auto testProbeVectors = probeVectors; auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) .numDrivers(numDrivers_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) @@ -914,10 +915,6 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { .joinOutputLayout({"u0", "u1"}) .referenceQuery( "SELECT u.* FROM u WHERE EXISTS (SELECT t0 FROM t WHERE u0 = t0 AND t1 % 5 = 0)") - .verifier([&](const std::shared_ptr& task, bool hasSpill) { - ASSERT_EQ( - getOutputPositions(task, "HashProbe"), 200 / 5 * 5 * numDrivers_); - }) .run(); } } @@ -985,6 +982,7 @@ TEST_P(MultiThreadedHashJoinTest, semiFilterOverLazyVectors) { .run(); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) .planNode(flipJoinSides(plan)) .inputSplits(splitInput) .checkSpillStats(false) @@ -1017,6 +1015,7 @@ TEST_P(MultiThreadedHashJoinTest, semiFilterOverLazyVectors) { .run(); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) .planNode(flipJoinSides(plan)) .inputSplits(splitInput) .checkSpillStats(false) @@ -1123,35 +1122,38 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilter) { }); }); - 
HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"t0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u0"}) - .buildVectors(std::move(buildVectors)) - .joinType(core::JoinType::kAnti) - .nullAware(true) - .joinFilter("t1 != u1") - .joinOutputLayout({"t0", "t1"}) - .referenceQuery( - "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0 AND t1 <> u1)") - .checkSpillStats(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - // Verify spilling is not triggered in case of null-aware anti-join - // with filter. - const auto statsPair = taskSpilledStats(*task); - ASSERT_EQ(statsPair.first.spilledRows, 0); - ASSERT_EQ(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 0); - ASSERT_EQ(statsPair.first.spilledFiles, 0); - ASSERT_EQ(statsPair.second.spilledRows, 0); - ASSERT_EQ(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 0); - ASSERT_EQ(statsPair.second.spilledFiles, 0); - verifyTaskSpilledRuntimeStats(*task, false); - ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); - }) - .run(); + // null-anti join with filter not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .probeKeys({"t0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::move(buildVectors)) + .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter("t1 != u1") + .joinOutputLayout({"t0", "t1"}) + .referenceQuery( + "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t0 = u0 AND t1 <> u1)") + .checkSpillStats(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + // Verify spilling is not triggered in case of null-aware anti-join + // with filter. 
+ const auto statsPair = taskSpilledStats(*task); + ASSERT_EQ(statsPair.first.spilledRows, 0); + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_EQ(statsPair.first.spilledFiles, 0); + ASSERT_EQ(statsPair.second.spilledRows, 0); + ASSERT_EQ(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 0); + ASSERT_EQ(statsPair.second.spilledFiles, 0); + verifyTaskSpilledRuntimeStats(*task, false); + ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); + }) + .run(), + "Replacement with cuDF operator failed"); } TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterAndEmptyBuild) { @@ -1176,37 +1178,40 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterAndEmptyBuild) { }); }); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) - .numDrivers(numDrivers_) - .probeKeys({"t0"}) - .probeVectors(std::vector(probeVectors)) - .buildKeys({"u0"}) - .buildVectors(std::vector(buildVectors)) - .buildFilter("u0 < 0") - .joinType(core::JoinType::kAnti) - .nullAware(true) - .joinFilter("u1 > t1") - .joinOutputLayout({"t0", "t1"}) - .referenceQuery( - "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u0 < 0 AND u.u0 = t.t0)") - .checkSpillStats(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - // Verify spilling is not triggered in case of null-aware anti-join - // with filter. 
- const auto statsPair = taskSpilledStats(*task); - ASSERT_EQ(statsPair.first.spilledRows, 0); - ASSERT_EQ(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 0); - ASSERT_EQ(statsPair.first.spilledFiles, 0); - ASSERT_EQ(statsPair.second.spilledRows, 0); - ASSERT_EQ(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 0); - ASSERT_EQ(statsPair.second.spilledFiles, 0); - verifyTaskSpilledRuntimeStats(*task, false); - ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); - }) - .run(); + // null-anti join with filter not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) + .numDrivers(numDrivers_) + .probeKeys({"t0"}) + .probeVectors(std::vector(probeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::vector(buildVectors)) + .buildFilter("u0 < 0") + .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter("u1 > t1") + .joinOutputLayout({"t0", "t1"}) + .referenceQuery( + "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u0 < 0 AND u.u0 = t.t0)") + .checkSpillStats(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + // Verify spilling is not triggered in case of null-aware + // anti-join with filter. 
+ const auto statsPair = taskSpilledStats(*task); + ASSERT_EQ(statsPair.first.spilledRows, 0); + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_EQ(statsPair.first.spilledFiles, 0); + ASSERT_EQ(statsPair.second.spilledRows, 0); + ASSERT_EQ(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 0); + ASSERT_EQ(statsPair.second.spilledFiles, 0); + verifyTaskSpilledRuntimeStats(*task, false); + ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); + }) + .run(), + "Replacement with cuDF operator failed"); } } @@ -1236,34 +1241,37 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterAndNullKey) { auto testProbeVectors = probeVectors; auto testBuildVectors = buildVectors; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"t0"}) - .probeVectors(std::move(testProbeVectors)) - .buildKeys({"u0"}) - .buildVectors(std::move(testBuildVectors)) - .joinType(core::JoinType::kAnti) - .nullAware(true) - .joinFilter(filter) - .joinOutputLayout({"t0", "t1"}) - .referenceQuery(referenceSql) - .checkSpillStats(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - // Verify spilling is not triggered in case of null-aware anti-join - // with filter. 
- const auto statsPair = taskSpilledStats(*task); - ASSERT_EQ(statsPair.first.spilledRows, 0); - ASSERT_EQ(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 0); - ASSERT_EQ(statsPair.first.spilledFiles, 0); - ASSERT_EQ(statsPair.second.spilledRows, 0); - ASSERT_EQ(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 0); - ASSERT_EQ(statsPair.second.spilledFiles, 0); - verifyTaskSpilledRuntimeStats(*task, false); - ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); - }) - .run(); + // null-anti join with filter not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .probeKeys({"t0"}) + .probeVectors(std::move(testProbeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::move(testBuildVectors)) + .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter(filter) + .joinOutputLayout({"t0", "t1"}) + .referenceQuery(referenceSql) + .checkSpillStats(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + // Verify spilling is not triggered in case of null-aware + // anti-join with filter. 
+ const auto statsPair = taskSpilledStats(*task); + ASSERT_EQ(statsPair.first.spilledRows, 0); + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_EQ(statsPair.first.spilledFiles, 0); + ASSERT_EQ(statsPair.second.spilledRows, 0); + ASSERT_EQ(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 0); + ASSERT_EQ(statsPair.second.spilledFiles, 0); + verifyTaskSpilledRuntimeStats(*task, false); + ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); + }) + .run(), + "Replacement with cuDF operator failed"); } } @@ -1296,19 +1304,22 @@ TEST_P( auto testProbeVectors = probeVectors; auto testBuildVectors = buildVectors; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"t0"}) - .probeVectors(std::move(testProbeVectors)) - .buildKeys({"u0"}) - .buildVectors(std::move(testBuildVectors)) - .joinType(core::JoinType::kAnti) - .nullAware(true) - .joinFilter(filter) - .joinOutputLayout({"t0", "t1"}) - .referenceQuery(referenceSql) - .checkSpillStats(false) - .run(); + // null-anti join with filter not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .probeKeys({"t0"}) + .probeVectors(std::move(testProbeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::move(testBuildVectors)) + .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter(filter) + .joinOutputLayout({"t0", "t1"}) + .referenceQuery(referenceSql) + .checkSpillStats(false) + .run(), + "Replacement with cuDF operator failed"); } } @@ -1334,34 +1345,37 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterOnNullableColumn) { makeFlatVector(234, folly::identity, nullEvery(91)), }); }); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"t0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u0"}) - 
.buildVectors(std::move(buildVectors)) - .joinType(core::JoinType::kAnti) - .nullAware(true) - .joinFilter(joinFilter) - .joinOutputLayout({"t0", "t1"}) - .referenceQuery(referenceSql) - .checkSpillStats(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - // Verify spilling is not triggered in case of null-aware anti-join - // with filter. - const auto statsPair = taskSpilledStats(*task); - ASSERT_EQ(statsPair.first.spilledRows, 0); - ASSERT_EQ(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 0); - ASSERT_EQ(statsPair.first.spilledFiles, 0); - ASSERT_EQ(statsPair.second.spilledRows, 0); - ASSERT_EQ(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 0); - ASSERT_EQ(statsPair.second.spilledFiles, 0); - verifyTaskSpilledRuntimeStats(*task, false); - ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); - }) - .run(); + // null-anti join with filter not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .probeKeys({"t0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::move(buildVectors)) + .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter(joinFilter) + .joinOutputLayout({"t0", "t1"}) + .referenceQuery(referenceSql) + .checkSpillStats(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + // Verify spilling is not triggered in case of null-aware + // anti-join with filter. 
+ const auto statsPair = taskSpilledStats(*task); + ASSERT_EQ(statsPair.first.spilledRows, 0); + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_EQ(statsPair.first.spilledFiles, 0); + ASSERT_EQ(statsPair.second.spilledRows, 0); + ASSERT_EQ(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 0); + ASSERT_EQ(statsPair.second.spilledFiles, 0); + verifyTaskSpilledRuntimeStats(*task, false); + ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); + }) + .run(), + "Replacement with cuDF operator failed"); } { @@ -1384,34 +1398,37 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterOnNullableColumn) { makeFlatVector(234, folly::identity, nullEvery(37)), }); }); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .probeKeys({"t0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u0"}) - .buildVectors(std::move(buildVectors)) - .joinType(core::JoinType::kAnti) - .nullAware(true) - .joinFilter(joinFilter) - .joinOutputLayout({"t0", "t1"}) - .referenceQuery(referenceSql) - .checkSpillStats(false) - .verifier([&](const std::shared_ptr& task, bool /*unused*/) { - // Verify spilling is not triggered in case of null-aware anti-join - // with filter. 
- const auto statsPair = taskSpilledStats(*task); - ASSERT_EQ(statsPair.first.spilledRows, 0); - ASSERT_EQ(statsPair.first.spilledBytes, 0); - ASSERT_EQ(statsPair.first.spilledPartitions, 0); - ASSERT_EQ(statsPair.first.spilledFiles, 0); - ASSERT_EQ(statsPair.second.spilledRows, 0); - ASSERT_EQ(statsPair.second.spilledBytes, 0); - ASSERT_EQ(statsPair.second.spilledPartitions, 0); - ASSERT_EQ(statsPair.second.spilledFiles, 0); - verifyTaskSpilledRuntimeStats(*task, false); - ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); - }) - .run(); + // null-anti join with filter not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .probeKeys({"t0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::move(buildVectors)) + .joinType(core::JoinType::kAnti) + .nullAware(true) + .joinFilter(joinFilter) + .joinOutputLayout({"t0", "t1"}) + .referenceQuery(referenceSql) + .checkSpillStats(false) + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + // Verify spilling is not triggered in case of null-aware + // anti-join with filter. 
+ const auto statsPair = taskSpilledStats(*task); + ASSERT_EQ(statsPair.first.spilledRows, 0); + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_EQ(statsPair.first.spilledFiles, 0); + ASSERT_EQ(statsPair.second.spilledRows, 0); + ASSERT_EQ(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 0); + ASSERT_EQ(statsPair.second.spilledFiles, 0); + verifyTaskSpilledRuntimeStats(*task, false); + ASSERT_EQ(maxHashBuildSpillLevel(*task), -1); + }) + .run(), + "Replacement with cuDF operator failed"); } } @@ -2042,6 +2059,23 @@ TEST_P(MultiThreadedHashJoinTest, rightJoin) { .joinOutputLayout({"c0", "c1", "u_c1"}) .referenceQuery( "SELECT t.c0, t.c1, u.c1 FROM t RIGHT JOIN u ON t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + int nullJoinBuildKeyCount = 0; + int nullJoinProbeKeyCount = 0; + + for (auto& pipeline : task->taskStats().pipelineStats) { + for (auto op : pipeline.operatorStats) { + if (op.operatorType == "CudfHashJoinBuild") { + nullJoinBuildKeyCount += op.numNullKeys; + } + if (op.operatorType == "CudfHashJoinProbe") { + nullJoinProbeKeyCount += op.numNullKeys; + } + } + } + ASSERT_GT(nullJoinBuildKeyCount, 0); + ASSERT_GT(nullJoinProbeKeyCount, 0); + }) .run(); } @@ -2269,19 +2303,22 @@ TEST_P(MultiThreadedHashJoinTest, fullJoin) { }); }); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .injectSpill(false) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kFull) - .joinOutputLayout({"c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0") - .run(); + // full join not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + 
.injectSpill(false) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kFull) + .joinOutputLayout({"c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0") + .run(), + "Replacement with cuDF operator failed"); } TEST_P(MultiThreadedHashJoinTest, fullJoinWithEmptyBuild) { @@ -2324,22 +2361,25 @@ TEST_P(MultiThreadedHashJoinTest, fullJoinWithEmptyBuild) { }); }); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) - .numDrivers(numDrivers_) - .injectSpill(false) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildFilter("c0 > 100") - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kFull) - .joinOutputLayout({"c1"}) - .referenceQuery( - "SELECT t.c1 FROM t FULL OUTER JOIN (SELECT * FROM u WHERE c0 > 100) u ON t.c0 = u.c0") - .checkSpillStats(false) - .run(); + // full join not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) + .numDrivers(numDrivers_) + .injectSpill(false) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildFilter("c0 > 100") + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kFull) + .joinOutputLayout({"c1"}) + .referenceQuery( + "SELECT t.c1 FROM t FULL OUTER JOIN (SELECT * FROM u WHERE c0 > 100) u ON t.c0 = u.c0") + .checkSpillStats(false) + .run(), + "Replacement with cuDF operator failed"); } } @@ -2379,20 +2419,23 @@ TEST_P(MultiThreadedHashJoinTest, fullJoinWithNoMatch) { }); }); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - 
.numDrivers(numDrivers_) - .injectSpill(false) - .probeKeys({"c0"}) - .probeVectors(std::move(probeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(buildVectors)) - .buildFilter("c0 < 0") - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kFull) - .joinOutputLayout({"c1"}) - .referenceQuery( - "SELECT t.c1 FROM t FULL OUTER JOIN (SELECT * FROM u WHERE c0 < 0) u ON t.c0 = u.c0") - .run(); + // full join not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .injectSpill(false) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildFilter("c0 < 0") + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kFull) + .joinOutputLayout({"c1"}) + .referenceQuery( + "SELECT t.c1 FROM t FULL OUTER JOIN (SELECT * FROM u WHERE c0 < 0) u ON t.c0 = u.c0") + .run(), + "Replacement with cuDF operator failed"); } TEST_P(MultiThreadedHashJoinTest, fullJoinWithFilters) { @@ -2435,40 +2478,46 @@ TEST_P(MultiThreadedHashJoinTest, fullJoinWithFilters) { { auto testProbeVectors = probeVectors; auto testBuildVectors = buildVectors; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .injectSpill(false) - .probeKeys({"c0"}) - .probeVectors(std::move(testProbeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(testBuildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kFull) - .joinFilter("(c1 + u_c1) % 2 = 1") - .joinOutputLayout({"c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 1") - .run(); + // full join not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .injectSpill(false) + .probeKeys({"c0"}) + 
.probeVectors(std::move(testProbeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(testBuildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kFull) + .joinFilter("(c1 + u_c1) % 2 = 1") + .joinOutputLayout({"c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 1") + .run(), + "Replacement with cuDF operator failed"); } // Filter without passed rows. { auto testProbeVectors = probeVectors; auto testBuildVectors = buildVectors; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .numDrivers(numDrivers_) - .injectSpill(false) - .probeKeys({"c0"}) - .probeVectors(std::move(testProbeVectors)) - .buildKeys({"u_c0"}) - .buildVectors(std::move(testBuildVectors)) - .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) - .joinType(core::JoinType::kFull) - .joinFilter("(c1 + u_c1) % 2 = 3") - .joinOutputLayout({"c0", "c1", "u_c1"}) - .referenceQuery( - "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 3") - .run(); + // full join not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .injectSpill(false) + .probeKeys({"c0"}) + .probeVectors(std::move(testProbeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(testBuildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kFull) + .joinFilter("(c1 + u_c1) % 2 = 3") + .joinOutputLayout({"c0", "c1", "u_c1"}) + .referenceQuery( + "SELECT t.c0, t.c1, u.c1 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (t.c1 + u.c1) % 2 = 3") + .run(), + "Replacement with cuDF operator failed"); } } @@ -2561,12 +2610,15 @@ TEST_F(HashJoinTest, nullAwareRightSemiProjectOverScan) { {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, }; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .inputSplits(splitInput) - 
.checkSpillStats(false) - .referenceQuery("SELECT u0, u0 IN (SELECT t0 FROM t) FROM u") - .run(); + // right semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .inputSplits(splitInput) + .checkSpillStats(false) + .referenceQuery("SELECT u0, u0 IN (SELECT t0 FROM t) FROM u") + .run(), + "Replacement with cuDF operator failed"); } } @@ -2598,6 +2650,7 @@ TEST_F(HashJoinTest, duplicateJoinKeys) { const std::vector& rightKeys, const std::vector& outputLayout, core::JoinType joinType, + bool throwType, const std::string& query) { auto plan = PlanBuilder(planNodeIdGenerator) .values(leftVectors) @@ -2613,18 +2666,32 @@ TEST_F(HashJoinTest, duplicateJoinKeys) { outputLayout, joinType) .planNode(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .injectSpill(false) - .checkSpillStats(false) - .referenceQuery(query) - .run(); + if (throwType) { + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery(query) + .run(), + "Replacement with cuDF operator failed"); + } else { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery(query) + .run(); + } }; std::vector> joins = { {core::JoinType::kInner, "INNER JOIN"}, {core::JoinType::kLeft, "LEFT JOIN"}, - {core::JoinType::kRight, "RIGHT JOIN"}, + {core::JoinType::kRight, "RIGHT JOIN"}}; + + // full outer join not supported + std::vector> throwingJoins = { {core::JoinType::kFull, "FULL OUTER JOIN"}}; for (const auto& [joinType, joinTypeSql] : joins) { @@ -2636,6 +2703,7 @@ TEST_F(HashJoinTest, duplicateJoinKeys) { {"u0", "u0"}, // rightKeys {"t0", "t1", "u0"}, // outputLayout joinType, + false, "SELECT t.c0, t.c1, u.c0 FROM t " + joinTypeSql + " u ON t.c0 = u.c0 and t.c1 = u.c0"); } @@ -2649,6 
+2717,31 @@ TEST_F(HashJoinTest, duplicateJoinKeys) { {"u0", "u1"}, // rightKeys {"t0", "u0", "u1"}, // outputLayout joinType, + false, + "SELECT t.c0, u.c0, u.c1 FROM t " + joinTypeSql + + " u ON t.c0 = u.c0 and t.c0 = u.c1"); + } + + for (const auto& [joinType, joinTypeSql] : throwingJoins) { + // Duplicate keys on the build side. + assertPlan( + {"c0 AS t0", "c1 as t1"}, // leftProject + {"t0", "t1"}, // leftKeys + {"c0 AS u0"}, // rightProject + {"u0", "u0"}, // rightKeys + {"t0", "t1", "u0"}, // outputLayout + joinType, + true, + "SELECT t.c0, t.c1, u.c0 FROM t " + joinTypeSql + + " u ON t.c0 = u.c0 and t.c1 = u.c0"); + assertPlan( + {"c0 AS t0"}, // leftProject + {"t0", "t0"}, // leftKeys + {"c0 AS u0", "c1 AS u1"}, // rightProject + {"u0", "u1"}, // rightKeys + {"t0", "u0", "u1"}, // outputLayout + joinType, + true, "SELECT t.c0, u.c0, u.c1 FROM t " + joinTypeSql + " u ON t.c0 = u.c0 and t.c0 = u.c1"); } @@ -2692,21 +2785,27 @@ TEST_F(HashJoinTest, semiProject) { core::JoinType::kLeftSemiProject) .planNode(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(plan) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + 
.injectSpill(false) + .checkSpillStats(false) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") + .run(), + "Replacement with cuDF operator failed"); // With extra filter. planNodeIdGenerator = std::make_shared(); @@ -2725,21 +2824,27 @@ TEST_F(HashJoinTest, semiProject) { core::JoinType::kLeftSemiProject) .planNode(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(plan) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND t.c1 * 10 <> u.c1) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND t.c1 * 10 <> u.c1) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND t.c1 * 10 <> u.c1) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND t.c1 * 10 <> u.c1) FROM t") + .run(), + "Replacement with cuDF operator failed"); // Empty build side. 
planNodeIdGenerator = std::make_shared(); @@ -2759,26 +2864,32 @@ TEST_F(HashJoinTest, semiProject) { core::JoinType::kLeftSemiProject) .planNode(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(plan) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE u.c0 < 0 AND t.c0 = u.c0) FROM t") - // NOTE: there is no spilling in empty build test case as all the - // build-side rows have been filtered out. - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE u.c0 < 0 AND t.c0 = u.c0) FROM t") + // NOTE:, there is no spilling in empty build test case as all the + // build-side rows have been filtered out. + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE u.c0 < 0 AND t.c0 = u.c0) FROM t") - // NOTE: there is no spilling in empty build test case as all the - // build-side rows have been filtered out. - .checkSpillStats(false) - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE u.c0 < 0 AND t.c0 = u.c0) FROM t") + // NOTE: there is no spilling in empty build test case as all the + // build-side rows have been filtered out. 
+ .checkSpillStats(false) + .run(), + "Replacement with cuDF operator failed"); } TEST_F(HashJoinTest, semiProjectWithNullKeys) { @@ -2835,185 +2946,245 @@ TEST_F(HashJoinTest, semiProjectWithNullKeys) { // Null join keys on both sides. auto plan = makePlan(false /*nullAware*/); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(plan) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t") + .run(), + "Replacement with cuDF operator failed"); plan = makePlan(true /*nullAware*/); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(plan) - .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") + .run(), + "Replacement with cuDF operator 
failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(flipJoinSides(plan)) - .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(flipJoinSides(plan)) + .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") + .run(), + "Replacement with cuDF operator failed"); // Null join keys on build side-only. plan = makePlan(false /*nullAware*/, "t0 IS NOT NULL"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(plan) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t WHERE t0 IS NOT NULL") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t WHERE t0 IS NOT NULL") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t WHERE t0 IS NOT NULL") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0) FROM t WHERE t0 IS NOT NULL") + .run(), + "Replacement with cuDF operator failed"); plan = makePlan(true /*nullAware*/, "t0 IS NOT NULL"); - HashJoinBuilder(*pool_, 
duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(plan) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t WHERE t0 IS NOT NULL") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t WHERE t0 IS NOT NULL") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t WHERE t0 IS NOT NULL") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t WHERE t0 IS NOT NULL") + .run(), + "Replacement with cuDF operator failed"); // Null join keys on probe side-only. 
plan = makePlan(false /*nullAware*/, "", "u0 IS NOT NULL"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(plan) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NOT NULL) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NOT NULL) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NOT NULL) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NOT NULL) FROM t") + .run(), + "Replacement with cuDF operator failed"); plan = makePlan(true /*nullAware*/, "", "u0 IS NOT NULL"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .injectSpill(false) - .checkSpillStats(false) - .planNode(plan) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NOT NULL) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NOT NULL) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - 
.injectSpill(false) - .checkSpillStats(false) - .planNode(flipJoinSides(plan)) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NOT NULL) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .injectSpill(false) + .checkSpillStats(false) + .planNode(flipJoinSides(plan)) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NOT NULL) FROM t") + .run(), + "Replacement with cuDF operator failed"); // Empty build side. plan = makePlan(false /*nullAware*/, "", "u0 < 0"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(plan) - .injectSpill(false) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 < 0) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(plan) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 < 0) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(flipJoinSides(plan)) - .injectSpill(false) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 < 0) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(flipJoinSides(plan)) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 < 0) FROM t") + .run(), + "Replacement with cuDF operator failed"); plan = makePlan(true /*nullAware*/, "", "u0 < 0"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(plan) - .injectSpill(false) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, 
t1, t0 IN (SELECT u0 FROM u WHERE u0 < 0) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(plan) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 < 0) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(flipJoinSides(plan)) - .injectSpill(false) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 < 0) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(flipJoinSides(plan)) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 < 0) FROM t") + .run(), + "Replacement with cuDF operator failed"); // Build side with all rows having null join keys. 
plan = makePlan(false /*nullAware*/, "", "u0 IS NULL"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(plan) - .injectSpill(false) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NULL) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(plan) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NULL) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(flipJoinSides(plan)) - .injectSpill(false) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NULL) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(flipJoinSides(plan)) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE u0 = t0 AND u0 IS NULL) FROM t") + .run(), + "Replacement with cuDF operator failed"); plan = makePlan(true /*nullAware*/, "", "u0 IS NULL"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(plan) - .injectSpill(false) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NULL) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(plan) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NULL) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) - .planNode(flipJoinSides(plan)) - .injectSpill(false) - 
.checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NULL) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, executor_.get()) + .planNode(flipJoinSides(plan)) + .injectSpill(false) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE u0 IS NULL) FROM t") + .run(), + "Replacement with cuDF operator failed"); } TEST_F(HashJoinTest, semiProjectWithFilter) { @@ -3063,24 +3234,31 @@ TEST_F(HashJoinTest, semiProjectWithFilter) { for (const auto& filter : filters) { auto plan = makePlan(true /*nullAware*/, filter); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery(fmt::format( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE {}) FROM t", filter)) - .injectSpill(false) - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery(fmt::format( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE {}) FROM t", + filter)) + .injectSpill(false) + .run(), + "Replacement with cuDF operator failed"); plan = makePlan(false /*nullAware*/, filter); // DuckDB Exists operator returns NULL when u0 or t0 is NULL. We exclude // these values. 
- HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .referenceQuery(fmt::format( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE (u0 is not null OR t0 is not null) AND u0 = t0 AND {}) FROM t", - filter)) - .injectSpill(false) - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .referenceQuery(fmt::format( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE (u0 is not null OR t0 is not null) AND u0 = t0 AND {}) FROM t", + filter)) + .injectSpill(false) + .run(), + "Replacement with cuDF operator failed"); } } @@ -3130,7 +3308,8 @@ TEST_F(HashJoinTest, leftSemiJoinWithExtraOutputCapacity) { createDuckDbTable("u", buildVectors); auto runQuery = [&](const std::string& query, const std::string& filter, - core::JoinType joinType) { + core::JoinType joinType, + bool throwType) { auto planNodeIdGenerator = std::make_shared(); std::vector outputLayout = {"t0", "t1"}; if (joinType == core::JoinType::kLeftSemiProject) { @@ -3149,12 +3328,23 @@ TEST_F(HashJoinTest, leftSemiJoinWithExtraOutputCapacity) { joinType, false) .planNode(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .config(core::QueryConfig::kPreferredOutputBatchRows, "5") - .referenceQuery(query) - .injectSpill(false) - .run(); + if (throwType) { + VELOX_ASSERT_RUNTIME_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .config(core::QueryConfig::kPreferredOutputBatchRows, "5") + .referenceQuery(query) + .injectSpill(false) + .run(), + "Replacement with cuDF operator failed"); + } else { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .config(core::QueryConfig::kPreferredOutputBatchRows, "5") + .referenceQuery(query) + .injectSpill(false) + .run(); + } }; { SCOPED_TRACE("left semi filter join"); @@ -3164,7 +3354,8 @@ TEST_F(HashJoinTest, 
leftSemiJoinWithExtraOutputCapacity) { "SELECT t0, t1 FROM t WHERE EXISTS (SELECT u0 FROM u WHERE t0 = u0 AND {})", filter), filter, - core::JoinType::kLeftSemiFilter); + core::JoinType::kLeftSemiFilter, + false); } { @@ -3174,7 +3365,8 @@ TEST_F(HashJoinTest, leftSemiJoinWithExtraOutputCapacity) { fmt::format( "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE {}) FROM t", filter), filter, - core::JoinType::kLeftSemiProject); + core::JoinType::kLeftSemiProject, + true); } } @@ -3283,19 +3475,25 @@ TEST_F(HashJoinTest, semiProjectOverLazyVectors) { {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, }; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .inputSplits(splitInput) + .checkSpillStats(false) + .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") - .run(); + // right semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .inputSplits(splitInput) + .checkSpillStats(false) + .referenceQuery("SELECT t0, t1, t0 IN (SELECT u0 FROM u) FROM t") + .run(), + "Replacement with cuDF operator failed"); // With extra filter. 
planNodeIdGenerator = std::make_shared(); @@ -3314,21 +3512,27 @@ TEST_F(HashJoinTest, semiProjectOverLazyVectors) { core::JoinType::kLeftSemiProject) .planNode(); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") - .run(); + // left semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .inputSplits(splitInput) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") + .run(), + "Replacement with cuDF operator failed"); - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(flipJoinSides(plan)) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") - .run(); + // right semi project not supported + VELOX_ASSERT_THROW( + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(flipJoinSides(plan)) + .inputSplits(splitInput) + .checkSpillStats(false) + .referenceQuery( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE (t1 + u1) % 3 = 0) FROM t") + .run(), + "Replacement with cuDF operator failed"); } VELOX_INSTANTIATE_TEST_SUITE_P( @@ -3527,22 +3731,28 @@ TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterFullJoin) { // Test the case where a filter loads a subset of the rows that will be output // from a column on the probe side. 
- testLazyVectorsWithFilter( - core::JoinType::kFull, - "c1 > 0 AND c2 > 0", - {"c1", "c2"}, - "SELECT t.c1, t.c2 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (c1 > 0 AND c2 > 0)"); + // full join not supported + VELOX_ASSERT_THROW( + testLazyVectorsWithFilter( + core::JoinType::kFull, + "c1 > 0 AND c2 > 0", + {"c1", "c2"}, + "SELECT t.c1, t.c2 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (c1 > 0 AND c2 > 0)"), + "Replacement with cuDF operator failed"); } TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiProject) { // Test the case where a filter loads a subset of the rows that will be output // from a column on the probe side. - testLazyVectorsWithFilter( - core::JoinType::kLeftSemiProject, - "c1 > 0 AND c2 > 0", - {"c1", "c2", "match"}, - "SELECT t.c1, t.c2, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND (t.c1 > 0 AND t.c2 > 0)) FROM t"); + // left semi project not supported + VELOX_ASSERT_THROW( + testLazyVectorsWithFilter( + core::JoinType::kLeftSemiProject, + "c1 > 0 AND c2 > 0", + {"c1", "c2", "match"}, + "SELECT t.c1, t.c2, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND (t.c1 > 0 AND t.c2 > 0)) FROM t"), + "Replacement with cuDF operator failed"); } TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterAntiJoin) { diff --git a/velox/experimental/cudf/tests/LocalPartitionTest.cpp b/velox/experimental/cudf/tests/LocalPartitionTest.cpp index 53b33df2b4f0..39136cc5f7c4 100644 --- a/velox/experimental/cudf/tests/LocalPartitionTest.cpp +++ b/velox/experimental/cudf/tests/LocalPartitionTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#include "velox/experimental/cudf/exec/CudfConversion.h" #include "velox/experimental/cudf/exec/ToCudf.h" #include "velox/common/base/tests/GTestUtils.h" @@ -158,5 +159,321 @@ TEST_F(LocalPartitionTest, partition) { queryBuilder.assertResults("SELECT c0, max(c0) FROM tmp GROUP BY 1"); } +TEST_F(LocalPartitionTest, unionAllLocalExchange) { + auto data1 = makeRowVector({"d0"}, {makeFlatVector({"x"})}); + auto data2 = makeRowVector({"e0"}, {makeFlatVector({"y"})}); + for (bool serialExecutionMode : {false, true}) { + SCOPED_TRACE(fmt::format("serialExecutionMode {}", serialExecutionMode)); + auto planNodeIdGenerator = std::make_shared(); + AssertQueryBuilder queryBuilder(duckDbQueryRunner_); + // applyTestParameters(queryBuilder); + queryBuilder.serialExecution(serialExecutionMode) + .plan(PlanBuilder(planNodeIdGenerator) + .localPartitionRoundRobin( + {PlanBuilder(planNodeIdGenerator) + .values({data1}) + .project({"d0 as c0"}) + .planNode(), + PlanBuilder(planNodeIdGenerator) + .values({data2}) + .project({"e0 as c0"}) + .planNode()}) + .project({"length(c0)"}) + .planNode()) + .assertResults( + "SELECT length(c0) FROM (" + " SELECT * FROM (VALUES ('x')) as t1(c0) UNION ALL " + " SELECT * FROM (VALUES ('y')) as t2(c0)" + ")"); + } +} + +TEST_F(LocalPartitionTest, roundRobinMultipleBatches) { + // Test round robin with multiple input batches to verify counter continuity + std::vector vectors = { + makeRowVector({makeFlatSequence(0, 5)}), // Batch 1: 5 rows + makeRowVector({makeFlatSequence(5, 5)}), // Batch 2: 5 rows + makeRowVector({makeFlatSequence(10, 5)}), // Batch 3: 5 rows + }; + + auto planNodeIdGenerator = std::make_shared(); + + auto op = + PlanBuilder(planNodeIdGenerator) + .localPartitionRoundRobin( + {PlanBuilder(planNodeIdGenerator).values(vectors).planNode()}) + .singleAggregation({}, {"count(1)", "min(c0)", "max(c0)"}) + .planNode(); + + // Total 15 rows, 3 partitions: each gets 5 rows + auto task = assertQuery(op, "SELECT 15, 0, 14"); + + auto 
stats = task->taskStats(); + ASSERT_EQ(stats.numTotalDrivers, 2); +} + +TEST_F(LocalPartitionTest, roundRobinEmptyInput) { + // Test with empty input + std::vector vectors = { + makeRowVector({makeFlatVector({})}), // Empty vector + }; + + auto planNodeIdGenerator = std::make_shared(); + + auto op = + PlanBuilder(planNodeIdGenerator) + .localPartitionRoundRobin( + {PlanBuilder(planNodeIdGenerator).values(vectors).planNode()}) + .singleAggregation({}, {"count(1)"}) + .planNode(); + + // Should handle empty input gracefully + auto task = assertQuery(op, "SELECT 0"); + + auto stats = task->taskStats(); + ASSERT_EQ(stats.numTotalDrivers, 2); +} + +TEST_F(LocalPartitionTest, roundRobinMultipleSources) { + // Test round robin with multiple input sources + std::vector vectors1 = { + makeRowVector({makeFlatSequence(0, 5)}), + }; + + std::vector vectors2 = { + makeRowVector({makeFlatSequence(5, 5)}), + }; + + auto planNodeIdGenerator = std::make_shared(); + + auto op = + PlanBuilder(planNodeIdGenerator) + .localPartitionRoundRobin({ + PlanBuilder(planNodeIdGenerator).values(vectors1).planNode(), + PlanBuilder(planNodeIdGenerator).values(vectors2).planNode(), + }) + .singleAggregation({}, {"count(1)", "min(c0)", "max(c0)"}) + .planNode(); + + // Total 10 rows from 2 sources, 3 partitions + auto task = assertQuery(op, "SELECT 10, 0, 9"); + + auto stats = task->taskStats(); + ASSERT_EQ(stats.numTotalDrivers, 3); +} + +TEST_F(LocalPartitionTest, roundRobinWithAggregation) { + // Test round robin followed by aggregation + std::vector vectors = { + makeRowVector({ + makeFlatSequence(0, 20), // 20 rows + makeFlatSequence(0, 4, 20), // Values 0,1,2,3 repeating + }), + }; + + auto planNodeIdGenerator = std::make_shared(); + + auto op = + PlanBuilder(planNodeIdGenerator) + .localPartitionRoundRobin( + {PlanBuilder(planNodeIdGenerator).values(vectors).planNode()}) + .singleAggregation({}, {"count(1)", "sum(c0)"}) + .planNode(); + + auto task = assertQuery(op, "SELECT 20, 190"); + + 
auto stats = task->taskStats(); + ASSERT_EQ(stats.numTotalDrivers, 2); +} + +TEST_F(LocalPartitionTest, roundRobinWithTableScan) { + // Test round robin with table scan sources + std::vector vectors = { + makeRowVector({makeFlatSequence(0, 30)}), + makeRowVector({makeFlatSequence(0, 12)}), + }; + + auto filePaths = writeToFiles(vectors); + auto rowType = asRowType(vectors[0]->type()); + + auto planNodeIdGenerator = std::make_shared(); + std::vector scanNodeIds; + + auto tableScanNode = [&]() { + auto node = PlanBuilder(planNodeIdGenerator).tableScan(rowType).planNode(); + scanNodeIds.push_back(node->id()); + return node; + }; + + auto op = PlanBuilder(planNodeIdGenerator) + .localPartitionRoundRobin({ + tableScanNode(), + tableScanNode(), + }) + .singleAggregation({}, {"count(1)", "min(c0)", "max(c0)"}) + .planNode(); + + createDuckDbTable(vectors); + + AssertQueryBuilder queryBuilder(op, duckDbQueryRunner_); + queryBuilder.maxDrivers(3); // 3 partitions + queryBuilder.config(core::QueryConfig::kMaxLocalExchangePartitionCount, "3"); + + for (auto i = 0; i < filePaths.size(); ++i) { + queryBuilder.split( + scanNodeIds[i], makeHiveConnectorSplit(filePaths[i]->getPath())); + } + + // 1st partition gets all rows + // 2nd partition gets zero rows + // 3rd partition gets zero rows + + auto task = queryBuilder.assertResults( + "SELECT 42, 0, 29 UNION ALL SELECT 0, NULL, NULL UNION ALL SELECT 0, NULL, NULL"); + + auto stats = task->taskStats(); + ASSERT_EQ(stats.numTotalDrivers, 9); +} + +TEST_F(LocalPartitionTest, roundRobinAllCombinations) { + // 6 value vectors in different scenarios. 
+ // Test all possible round robin distribution configurations: + for (auto [num_sources, num_partitions] : + {std::pair{3, 1}, + {2, 3}, + {3, 2}, + {1, 3}, + {2, 2}, + {2, 5}, + {1, 5}, + {5, 2}, + {5, 3}}) { + // Test to verify actual distribution pattern + std::vector vectors = { + makeRowVector({makeFlatSequence(0, 12)}), // 12 rows + makeRowVector({makeFlatSequence(0, 4)}), // 4 rows + makeRowVector({makeFlatSequence(0, 40)}), // 40 rows + makeRowVector({makeFlatSequence(0, 8)}), // 8 rows + makeRowVector({makeFlatSequence(0, 6)}), // 6 rows + makeRowVector({makeFlatSequence(0, 25)}), // 25 rows + }; + auto expectedTotalRows = std::accumulate( + vectors.begin(), + vectors.end(), + 0, + [](int sum, const RowVectorPtr& vector) { + return sum + vector->size(); + }); + std::vector expected_partition_counts(num_partitions, 0); + for (int i = 0; i < vectors.size(); ++i) { + expected_partition_counts[i % num_partitions] += 1; + } + + auto planNodeIdGenerator = std::make_shared(); + + auto make_op = [&](int num_sources) { + std::vector sources; + for (size_t i = 0; i < num_sources; ++i) { + auto num_vectors = (vectors.size() + num_sources - 1) / num_sources; + auto start = std::min(i * num_vectors, vectors.size()); + auto end = std::min(start + num_vectors, vectors.size()); + std::vector source_vectors( + vectors.begin() + start, vectors.begin() + end); + if (source_vectors.empty()) { + source_vectors.push_back( + makeRowVector({makeFlatSequence(0, 0)})); + } + sources.push_back( + PlanBuilder(planNodeIdGenerator).values(source_vectors).planNode()); + } + return PlanBuilder(planNodeIdGenerator) + .localPartitionRoundRobin(sources) + .planNode(); + }; + + auto op = make_op(num_sources); + + // Use cursor to examine actual distribution + CursorParameters params; + params.queryConfigs.insert( + {cudf_velox::CudfFromVelox::kGpuBatchSizeRows, "1"}); + params.queryConfigs.insert( + {core::QueryConfig::kMaxLocalExchangePartitionCount, + std::to_string(num_partitions)}); + + 
params.planNode = op; + params.maxDrivers = num_partitions; + params.copyResult = false; + + auto cursor = TaskCursor::create(params); + + std::vector partitionCounts(num_partitions, 0); + int partition = 0; + int totalRows = 0; + + while (cursor->moveNext()) { + auto* batch = cursor->current()->as(); + partitionCounts[partition++ % num_partitions] += 1; + totalRows += batch->size(); + } + + EXPECT_EQ(totalRows, expectedTotalRows) + << "for {" << num_sources << ", " << num_partitions << "}"; + // With round robin, each batch should go into separate partitions. + for (int i = 0; i < num_partitions; ++i) { + auto expected_count = expected_partition_counts[i]; + ASSERT_EQ(partitionCounts[i], expected_count) + << "for i " << i << ", {" << num_sources << ", " << num_partitions + << "}"; + } + } +} + +TEST_F(LocalPartitionTest, roundRobinDistributionVerification) { + // Test to verify actual distribution pattern + std::vector vectors = { + makeRowVector({makeFlatSequence(0, 12)}), // 12 rows + makeRowVector({makeFlatSequence(0, 4)}), // 4 rows + makeRowVector({makeFlatSequence(0, 40)}), // 40 rows + }; + + auto planNodeIdGenerator = std::make_shared(); + + auto op = + PlanBuilder(planNodeIdGenerator) + .localPartitionRoundRobin( + {PlanBuilder(planNodeIdGenerator).values(vectors).planNode()}) + .planNode(); + + // Use cursor to examine actual distribution + CursorParameters params; + params.queryConfigs.insert( + {cudf_velox::CudfFromVelox::kGpuBatchSizeRows, "1"}); + params.queryConfigs.insert( + {core::QueryConfig::kMaxLocalExchangePartitionCount, "3"}); + + params.planNode = op; + params.maxDrivers = 3; + params.copyResult = false; + + auto cursor = TaskCursor::create(params); + + std::vector partitionCounts(3, 0); + int partition = 0; + int totalRows = 0; + + while (cursor->moveNext()) { + auto* batch = cursor->current()->as(); + partitionCounts[partition++ % 3] += 1; + totalRows += batch->size(); + } + + ASSERT_EQ(totalRows, 56); + // With round robin, each batch 
should go into separate partitions. + ASSERT_EQ(partitionCounts[0], 1); + ASSERT_EQ(partitionCounts[1], 1); + ASSERT_EQ(partitionCounts[2], 1); +} + } // namespace } // namespace facebook::velox::exec::test diff --git a/velox/experimental/cudf/tests/TableScanTest.cpp b/velox/experimental/cudf/tests/TableScanTest.cpp index e86d82d3d3c1..8fe394b1786f 100644 --- a/velox/experimental/cudf/tests/TableScanTest.cpp +++ b/velox/experimental/cudf/tests/TableScanTest.cpp @@ -14,18 +14,19 @@ * limitations under the License. */ -#include "velox/experimental/cudf/connectors/parquet/ParquetConfig.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetConnector.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetDataSource.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h" -#include "velox/experimental/cudf/tests/utils/ParquetConnectorTestBase.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConfig.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnector.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveDataSource.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.h" +#include "velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/tests/FaultyFile.h" #include "velox/common/file/tests/FaultyFileSystem.h" #include "velox/common/memory/MemoryArbitrator.h" #include "velox/common/testutil/TestValue.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/exec/Exchange.h" #include "velox/exec/PlanNodeStats.h" @@ -37,6 +38,7 @@ #include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/expression/ExprToSubfieldFilter.h" #include "velox/type/Type.h" +#include 
"velox/type/tests/SubfieldFiltersBuilder.h" #include @@ -51,16 +53,16 @@ using namespace facebook::velox::cudf_velox; using namespace facebook::velox::cudf_velox::exec; using namespace facebook::velox::cudf_velox::exec::test; -class TableScanTest : public virtual ParquetConnectorTestBase { +class TableScanTest : public virtual CudfHiveConnectorTestBase { protected: void SetUp() override { - ParquetConnectorTestBase::SetUp(); + CudfHiveConnectorTestBase::SetUp(); ExchangeSource::factories().clear(); ExchangeSource::registerFactory(createLocalExchangeSource); } static void SetUpTestCase() { - ParquetConnectorTestBase::SetUpTestCase(); + CudfHiveConnectorTestBase::SetUpTestCase(); } std::vector makeVectors( @@ -68,11 +70,11 @@ class TableScanTest : public virtual ParquetConnectorTestBase { int32_t rowsPerVector, const RowTypePtr& rowType = nullptr) { auto inputs = rowType ? rowType : rowType_; - return ParquetConnectorTestBase::makeVectors(inputs, count, rowsPerVector); + return CudfHiveConnectorTestBase::makeVectors(inputs, count, rowsPerVector); } - Split makeParquetSplit(std::string path, int64_t splitWeight = 0) { - return Split(makeParquetConnectorSplit(std::move(path), splitWeight)); + Split makeCudfHiveSplit(std::string path, int64_t splitWeight = 0) { + return Split(makeCudfHiveConnectorSplit(std::move(path), splitWeight)); } std::shared_ptr assertQuery( @@ -94,7 +96,7 @@ class TableScanTest : public virtual ParquetConnectorTestBase { const PlanNodePtr& plan, const std::vector>& filePaths, const std::string& duckDbSql) { - return ParquetConnectorTestBase::assertQuery(plan, filePaths, duckDbSql); + return CudfHiveConnectorTestBase::assertQuery(plan, filePaths, duckDbSql); } // Run query with spill enabled. 
@@ -107,7 +109,7 @@ class TableScanTest : public virtual ParquetConnectorTestBase { .spillDirectory(spillDirectory) .config(core::QueryConfig::kSpillEnabled, false) .config(core::QueryConfig::kAggregationSpillEnabled, false) - .splits(makeParquetConnectorSplits(filePaths)) + .splits(makeCudfHiveConnectorSplits(filePaths)) .assertResults(duckDbSql); } @@ -132,17 +134,20 @@ class TableScanTest : public virtual ParquetConnectorTestBase { static std::unordered_map getTableScanRuntimeStats(const std::shared_ptr& task) { - VELOX_NYI("RuntimeStats not yet implemented for the cudf ParquetConnector"); + VELOX_NYI( + "RuntimeStats not yet implemented for the cudf CudfHiveConnector"); // return task->taskStats().pipelineStats[0].operatorStats[0].runtimeStats; } static int64_t getSkippedStridesStat(const std::shared_ptr& task) { - VELOX_NYI("RuntimeStats not yet implemented for the cudf ParquetConnector"); + VELOX_NYI( + "RuntimeStats not yet implemented for the cudf CudfHiveConnector"); // return getTableScanRuntimeStats(task)["skippedStrides"].sum; } static int64_t getSkippedSplitsStat(const std::shared_ptr& task) { - VELOX_NYI("RuntimeStats not yet implemented for the cudf ParquetConnector"); + VELOX_NYI( + "RuntimeStats not yet implemented for the cudf CudfHiveConnector"); // return getTableScanRuntimeStats(task)["skippedSplits"].sum; } @@ -195,7 +200,8 @@ TEST_F(TableScanTest, allColumns) { auto scanNodeId = plan->id(); auto it = planStats.find(scanNodeId); ASSERT_TRUE(it != planStats.end()); - ASSERT_TRUE(it->second.peakMemoryBytes > 0); + // TODO (dm): enable this test once we start to track gpu memory + // ASSERT_TRUE(it->second.peakMemoryBytes > 0); // Verifies there is no dynamic filter stats. 
ASSERT_TRUE(it->second.dynamicFilterStats.empty()); @@ -204,9 +210,9 @@ TEST_F(TableScanTest, allColumns) { // ASSERT_LT(0, it->second.customStats.at("ioWaitWallNanos").sum); }; - // Test scan all columns with ParquetConnectorSplits + // Test scan all columns with CudfHiveConnectorSplits { - auto splits = makeParquetConnectorSplits({filePath}); + auto splits = makeCudfHiveConnectorSplits({filePath}); testScanAllColumns(splits); } @@ -221,8 +227,9 @@ TEST_F(TableScanTest, allColumns) { splits; for (const auto& filePath : filePaths) { splits.push_back( - hive::HiveConnectorSplitBuilder(filePath->getPath()) - .connectorId(kParquetConnectorId) + facebook::velox::connector::hive::HiveConnectorSplitBuilder( + filePath->getPath()) + .connectorId(kCudfHiveConnectorId) .fileFormat(dwio::common::FileFormat::PARQUET) .build()); } @@ -264,7 +271,7 @@ TEST_F(TableScanTest, directBufferInputRawInputBytes) { auto task = AssertQueryBuilder(duckDbQueryRunner_) .plan(plan) - .splits(makeParquetConnectorSplits({filePath})) + .splits(makeCudfHiveConnectorSplits({filePath})) .queryCtx(queryCtx) .assertResults("SELECT c0, c2 FROM tmp"); @@ -275,11 +282,11 @@ TEST_F(TableScanTest, directBufferInputRawInputBytes) { auto it = planStats.find(scanNodeId); ASSERT_TRUE(it != planStats.end()); auto rawInputBytes = it->second.rawInputBytes; - // Reduced from 500 to 400 as cudf Parquet writer seems to be writing smaller + // Reduced from 500 to 400 as cudf CudfHive writer seems to be writing smaller // files. 
ASSERT_GE(rawInputBytes, 400); - // TableScan runtime stats not available with Parquet connector yet + // TableScan runtime stats not available with CudfHive connector yet #if 0 auto overreadBytes = getTableScanRuntimeStats(task).at("overreadBytes").sum; @@ -324,29 +331,17 @@ TEST_F(TableScanTest, filterPushdown) { createDuckDbTable(vectors); // c1 >= 0 or null and c3 is true - // common::SubfieldFilters subfieldFilters = - // SubfieldFiltersBuilder() - // .add("c1", greaterThanOrEqual(0, true)) - // .add("c3", std::make_unique(true, false)) - // .build(); - // convert subfieldFilters to a typed expression - // c1 >= 0 or null and c3 is true - auto c1Expr = std::make_shared( - BOOLEAN(), - "gte", - std::make_shared(BIGINT(), "c1"), - std::make_shared(BIGINT(), int64_t(0))); - - auto c3Expr = std::make_shared( - BOOLEAN(), - "eq", - std::make_shared(BOOLEAN(), "c3"), - std::make_shared(BOOLEAN(), true)); - - auto subfieldFilterExpr = - std::make_shared(BOOLEAN(), "and", c1Expr, c3Expr); + common::SubfieldFilters subfieldFilters = + common::test::SubfieldFiltersBuilder() + .add( + "c1", + std::make_unique( + int64_t(0), std::numeric_limits::max(), true)) + .add("c3", std::make_unique(true, false)) + .build(); + auto tableHandle = makeTableHandle( - "parquet_table", rowType, true, std::move(subfieldFilterExpr), nullptr); + "parquet_table", rowType, true, std::move(subfieldFilters), nullptr); auto assignments = facebook::velox::exec::test::HiveConnectorTestBase::allRegularColumns( @@ -422,3 +417,25 @@ TEST_F(TableScanTest, filterPushdown) { "SELECT count(*) FROM tmp"); #endif } + +TEST_F(TableScanTest, splitOffset) { + auto vectors = makeVectors(1, 10); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vectors); + + auto plan = tableScanNode(); + + auto split = facebook::velox::connector::hive::HiveConnectorSplitBuilder( + filePath->getPath()) + .connectorId(kCudfHiveConnectorId) + .start(1) + .fileFormat(dwio::common::FileFormat::PARQUET) + 
.build(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .splits({split}) + .assertEmptyResults(), + "CudfHiveDataSource cannot process splits with non-zero offset"); +} diff --git a/velox/experimental/cudf/tests/TableWriteTest.cpp b/velox/experimental/cudf/tests/TableWriteTest.cpp index c25ab4432288..71c1b461cf87 100644 --- a/velox/experimental/cudf/tests/TableWriteTest.cpp +++ b/velox/experimental/cudf/tests/TableWriteTest.cpp @@ -14,14 +14,14 @@ * limitations under the License. */ -#include "velox/experimental/cudf/connectors/parquet/ParquetConfig.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetConnector.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetConnectorSplit.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetDataSource.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConfig.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnector.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveDataSource.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.h" #include "velox/experimental/cudf/exec/ToCudf.h" +#include "velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.h" #include "velox/experimental/cudf/tests/utils/CudfPlanBuilder.h" -#include "velox/experimental/cudf/tests/utils/ParquetConnectorTestBase.h" #include "folly/dynamic.h" #include "velox/common/base/Fs.h" @@ -35,6 +35,7 @@ #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" +#include #include #include @@ -55,13 +56,14 @@ using namespace facebook::velox::cudf_velox::exec::test; constexpr uint64_t kQueryMemoryCapacity = 512 * MB; +DEFINE_bool(velox_cudf_debug, false, "Enable debug printing"); + namespace { -static std::shared_ptr generateColumnStatsSpec( 
+static core::ColumnStatsSpec generateColumnStatsSpec( const std::string& name, const std::vector& groupingKeys, - AggregationNode::Step step, - const PlanNodePtr& source) { + AggregationNode::Step step) { core::TypedExprPtr inputField = std::make_shared(BIGINT(), name); auto callExpr = @@ -70,15 +72,7 @@ static std::shared_ptr generateColumnStatsSpec( std::vector aggregates = { core::AggregationNode::Aggregate{ callExpr, {{BIGINT()}}, nullptr, {}, {}}}; - return std::make_shared( - core::PlanNodeId(), - step, - groupingKeys, - std::vector{}, - aggregateNames, - aggregates, - false, // ignoreNullKeys - source); + return core::ColumnStatsSpec{groupingKeys, step, aggregateNames, aggregates}; } } // namespace @@ -152,7 +146,7 @@ struct TestParam { } }; -class TableWriteTest : public ParquetConnectorTestBase { +class TableWriteTest : public CudfHiveConnectorTestBase { protected: explicit TableWriteTest(uint64_t testValue) : testParam_(static_cast(testValue)), @@ -163,7 +157,7 @@ class TableWriteTest : public ParquetConnectorTestBase { commitStrategy_(testParam_.commitStrategy()), compressionKind_(testParam_.compressionKind()) { LOG(INFO) << testParam_.toString(); - if (cudfDebugEnabled()) { + if (FLAGS_velox_cudf_debug) { std::cout << testParam_.toString() << std::endl; } @@ -174,7 +168,7 @@ class TableWriteTest : public ParquetConnectorTestBase { } void SetUp() override { - ParquetConnectorTestBase::SetUp(); + CudfHiveConnectorTestBase::SetUp(); } std::shared_ptr assertQueryWithWriterConfigs( @@ -185,7 +179,7 @@ class TableWriteTest : public ParquetConnectorTestBase { std::vector splits; for (const auto& filePath : filePaths) { splits.push_back(facebook::velox::exec::Split( - makeParquetConnectorSplit(filePath->getPath()))); + makeCudfHiveConnectorSplit(filePath->getPath()))); } if (!spillEnabled) { return AssertQueryBuilder(plan, duckDbQueryRunner_) @@ -264,19 +258,19 @@ class TableWriteTest : public ParquetConnectorTestBase { } std::vector> - 
makeParquetConnectorSplits( + makeCudfHiveConnectorSplits( const std::shared_ptr& directoryPath) { - return makeParquetConnectorSplits(directoryPath->getPath()); + return makeCudfHiveConnectorSplits(directoryPath->getPath()); } std::vector> - makeParquetConnectorSplits(const std::string& directoryPath) { + makeCudfHiveConnectorSplits(const std::string& directoryPath) { std::vector> splits; for (auto& path : fs::recursive_directory_iterator(directoryPath)) { if (path.is_regular_file()) { - splits.push_back(ParquetConnectorTestBase::makeParquetConnectorSplits( + splits.push_back(CudfHiveConnectorTestBase::makeCudfHiveConnectorSplits( path.path().string(), 1)[0]); } } @@ -298,12 +292,12 @@ class TableWriteTest : public ParquetConnectorTestBase { // Builds and returns the parquet splits from the list of files with one split // per each file. std::vector> - makeParquetConnectorSplits( + makeCudfHiveConnectorSplits( const std::vector& filePaths) { std::vector> splits; for (const auto& filePath : filePaths) { - splits.push_back(ParquetConnectorTestBase::makeParquetConnectorSplits( + splits.push_back(CudfHiveConnectorTestBase::makeCudfHiveConnectorSplits( filePath.string(), 1)[0]); } return splits; @@ -312,7 +306,7 @@ class TableWriteTest : public ParquetConnectorTestBase { std::vector makeVectors( int32_t numVectors, int32_t rowsPerVector) { - return ParquetConnectorTestBase::makeVectors( + return CudfHiveConnectorTestBase::makeVectors( rowType_, numVectors, rowsPerVector); } @@ -366,13 +360,13 @@ class TableWriteTest : public ParquetConnectorTestBase { // Helper method to return InsertTableHandle. 
std::shared_ptr createInsertTableHandle( const RowTypePtr& outputRowType, - const cudf_velox::connector::parquet::LocationHandle::TableType& + const cudf_velox::connector::hive::LocationHandle::TableType& outputTableType, const std::string& outputDirectoryPath, const std::optional compressionKind = {}) { return std::make_shared( - kParquetConnectorId, - makeParquetInsertTableHandle( + kCudfHiveConnectorId, + makeCudfHiveInsertTableHandle( outputRowType->names(), outputRowType->children(), makeLocationHandle(outputDirectoryPath, outputTableType), @@ -386,12 +380,29 @@ class TableWriteTest : public ParquetConnectorTestBase { const std::string& outputDirectoryPath, const std::optional compressionKind = {}, int numTableWriters = 1, - const cudf_velox::connector::parquet::LocationHandle::TableType& + const cudf_velox::connector::hive::LocationHandle::TableType& outputTableType = - cudf_velox::connector::parquet::LocationHandle::TableType::kNew, + cudf_velox::connector::hive::LocationHandle::TableType::kNew, const CommitStrategy& outputCommitStrategy = CommitStrategy::kNoCommit, bool aggregateResult = true, std::shared_ptr aggregationNode = nullptr) { + std::optional columnStatsSpec = std::nullopt; + if (aggregationNode != nullptr) { + // Convert AggregationNode to ColumnStatsSpec + VELOX_CHECK(!aggregationNode->ignoreNullKeys()); + VELOX_CHECK(!aggregationNode->groupId().has_value()); + VELOX_CHECK(!aggregationNode->isPreGrouped()); + VELOX_CHECK(aggregationNode->globalGroupingSets().empty()); + VELOX_CHECK(!aggregationNode->aggregateNames().empty()); + VELOX_CHECK_EQ( + aggregationNode->aggregateNames().size(), + aggregationNode->aggregates().size()); + columnStatsSpec = core::ColumnStatsSpec{ + aggregationNode->groupingKeys(), + aggregationNode->step(), + aggregationNode->aggregateNames(), + aggregationNode->aggregates()}; + } return createInsertPlan( inputPlan, inputPlan.planNode()->outputType(), @@ -402,7 +413,7 @@ class TableWriteTest : public 
ParquetConnectorTestBase { outputTableType, outputCommitStrategy, aggregateResult, - aggregationNode); + columnStatsSpec); } PlanNodePtr createInsertPlan( @@ -412,12 +423,12 @@ class TableWriteTest : public ParquetConnectorTestBase { const std::string& outputDirectoryPath, const std::optional compressionKind = {}, int numTableWriters = 1, - const cudf_velox::connector::parquet::LocationHandle::TableType& + const cudf_velox::connector::hive::LocationHandle::TableType& outputTableType = - cudf_velox::connector::parquet::LocationHandle::TableType::kNew, + cudf_velox::connector::hive::LocationHandle::TableType::kNew, const CommitStrategy& outputCommitStrategy = CommitStrategy::kNoCommit, bool aggregateResult = true, - std::shared_ptr aggregationNode = nullptr) { + std::optional columnStatsSpec = std::nullopt) { VELOX_CHECK( numTableWriters == 1, "Multiple CudfTableWriters not yet supported"); return createInsertPlanWithSingleWriter( @@ -429,7 +440,7 @@ class TableWriteTest : public ParquetConnectorTestBase { outputTableType, outputCommitStrategy, aggregateResult, - aggregationNode); + columnStatsSpec); } PlanNodePtr createInsertPlanWithSingleWriter( @@ -438,18 +449,18 @@ class TableWriteTest : public ParquetConnectorTestBase { const RowTypePtr& tableRowType, const std::string& outputDirectoryPath, const std::optional compressionKind, - const cudf_velox::connector::parquet::LocationHandle::TableType& + const cudf_velox::connector::hive::LocationHandle::TableType& outputTableType, const CommitStrategy& outputCommitStrategy, bool aggregateResult, - std::shared_ptr aggregationNode) { + std::optional columnStatsSpec) { const bool addScaleWriterExchange = false; auto insertPlan = inputPlan; insertPlan .addNode(addCudfTableWriter( inputRowType, tableRowType->names(), - aggregationNode, + columnStatsSpec, createInsertTableHandle( tableRowType, outputTableType, @@ -481,7 +492,7 @@ class TableWriteTest : public ParquetConnectorTestBase { return inputNames; } - // Parameter 
partitionName is string formatted in the Parquet style + // Parameter partitionName is string formatted in the CudfHive style // key1=value1/key2=value2/... Parameter partitionTypes are types of partition // keys in the same order as in partitionName.The return value is a SQL // predicate with values single quoted for string and date and not quoted for @@ -579,7 +590,7 @@ class TableWriteTest : public ParquetConnectorTestBase { core::PlanNodeId tableWriteNodeId_; }; -class BasicTableWriteTest : public ParquetConnectorTestBase {}; +class BasicTableWriteTest : public CudfHiveConnectorTestBase {}; TEST_F(BasicTableWriteTest, roundTrip) { vector_size_t size = 1'000; @@ -598,14 +609,14 @@ TEST_F(BasicTableWriteTest, roundTrip) { auto plan = PlanBuilder() .startTableScan() .outputType(rowType) - .tableHandle(ParquetConnectorTestBase::makeTableHandle()) + .tableHandle(CudfHiveConnectorTestBase::makeTableHandle()) .endTableScan() .addNode(cudfTableWrite(targetDirectoryPath->getPath())) .planNode(); auto results = AssertQueryBuilder(plan) - .split(makeParquetConnectorSplit(sourceFilePath->getPath())) + .split(makeCudfHiveConnectorSplit(sourceFilePath->getPath())) .copyResults(pool()); ASSERT_EQ(2, results->size()); @@ -633,12 +644,12 @@ TEST_F(BasicTableWriteTest, roundTrip) { plan = PlanBuilder() .startTableScan() .outputType(rowType) - .tableHandle(ParquetConnectorTestBase::makeTableHandle()) + .tableHandle(CudfHiveConnectorTestBase::makeTableHandle()) .endTableScan() .planNode(); auto copy = AssertQueryBuilder(plan) - .split(makeParquetConnectorSplit(fmt::format( + .split(makeCudfHiveConnectorSplit(fmt::format( "{}/{}", targetDirectoryPath->getPath(), writeFileName))) .copyResults(pool()); assertEqualResults({data}, {copy}); @@ -668,11 +679,11 @@ TEST_F(BasicTableWriteTest, targetFileName) { plan = PlanBuilder() .startTableScan() .outputType(asRowType(data->type())) - .tableHandle(ParquetConnectorTestBase::makeTableHandle()) + 
.tableHandle(CudfHiveConnectorTestBase::makeTableHandle()) .endTableScan() .planNode(); AssertQueryBuilder(plan) - .split(makeParquetConnectorSplit( + .split(makeCudfHiveConnectorSplit( fmt::format("{}/{}", directory->getPath(), kFileName))) .assertResults(data); } @@ -725,7 +736,7 @@ TEST_P(UnpartitionedTableWriterTest, differentCompression) { outputDirectory->getPath(), compressionKind, numTableWriterCount_, - cudf_velox::connector::parquet::LocationHandle::TableType::kNew), + cudf_velox::connector::hive::LocationHandle::TableType::kNew), "Unsupported compression type: CompressionKind_MAX"); return; } @@ -735,7 +746,7 @@ TEST_P(UnpartitionedTableWriterTest, differentCompression) { outputDirectory->getPath(), compressionKind, numTableWriterCount_, - cudf_velox::connector::parquet::LocationHandle::TableType::kNew); + cudf_velox::connector::hive::LocationHandle::TableType::kNew); auto result = AssertQueryBuilder(plan) .config( @@ -747,11 +758,11 @@ TEST_P(UnpartitionedTableWriterTest, differentCompression) { } } -// Test not really needed as we always write a TableType::kNew table in Parquet +// Test not really needed as we always write a TableType::kNew table in CudfHive // DataSink TEST_P(UnpartitionedTableWriterTest, immutableSettings) { struct { - cudf_velox::connector::parquet::LocationHandle::TableType dataType; + cudf_velox::connector::hive::LocationHandle::TableType dataType; bool immutableFilesEnabled; bool expectedInsertSuccees; @@ -763,10 +774,10 @@ TEST_P(UnpartitionedTableWriterTest, immutableSettings) { expectedInsertSuccees); } } testSettings[] = { - {cudf_velox::connector::parquet::LocationHandle::TableType::kNew, + {cudf_velox::connector::hive::LocationHandle::TableType::kNew, true, true}, - {cudf_velox::connector::parquet::LocationHandle::TableType::kNew, + {cudf_velox::connector::hive::LocationHandle::TableType::kNew, false, true}}; @@ -777,7 +788,7 @@ TEST_P(UnpartitionedTableWriterTest, immutableSettings) { testData.immutableFilesEnabled ? 
"true" : "false"}}; std::shared_ptr config{ std::make_shared(std::move(propFromFile))}; - resetParquetConnector(config); + resetCudfHiveConnector(config); auto input = makeVectors(10, 10); auto outputDirectory = TempDirectoryPath::create(); @@ -792,7 +803,7 @@ TEST_P(UnpartitionedTableWriterTest, immutableSettings) { if (!testData.expectedInsertSuccees) { VELOX_ASSERT_THROW( AssertQueryBuilder(plan).copyResults(pool()), - "Parquet tables are immutable."); + "CudfHive tables are immutable."); } else { auto result = AssertQueryBuilder(plan) .config( diff --git a/velox/experimental/cudf/tests/sparksql/CMakeLists.txt b/velox/experimental/cudf/tests/sparksql/CMakeLists.txt index 8c4d8124d6da..baa31524822d 100644 --- a/velox/experimental/cudf/tests/sparksql/CMakeLists.txt +++ b/velox/experimental/cudf/tests/sparksql/CMakeLists.txt @@ -36,3 +36,28 @@ target_link_libraries( GTest::gtest GTest::gtest_main ) + +add_executable(velox_cudf_spark_filter_project_test FilterProjectTest.cpp Main.cpp) + +add_test( + NAME velox_cudf_spark_filter_project_test + COMMAND velox_cudf_spark_filter_project_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties( + velox_cudf_spark_filter_project_test + PROPERTIES LABELS cuda_driver TIMEOUT 3000 +) + +target_link_libraries( + velox_cudf_spark_filter_project_test + velox_cudf_exec + velox_exec_test_lib + velox_vector_test_lib + velox_functions_spark + velox_vector_fuzzer + gflags::gflags + GTest::gtest + GTest::gtest_main +) diff --git a/velox/experimental/cudf/tests/sparksql/FilterProjectTest.cpp b/velox/experimental/cudf/tests/sparksql/FilterProjectTest.cpp new file mode 100644 index 000000000000..fd90ea423ec8 --- /dev/null +++ b/velox/experimental/cudf/tests/sparksql/FilterProjectTest.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/experimental/cudf/exec/CudfFilterProject.h" +#include "velox/experimental/cudf/exec/ToCudf.h" + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/common/tests/utils/BatchMaker.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +using namespace facebook::velox::exec::test; +using namespace facebook::velox; + +namespace { + +class CudfFilterProjectTest + : public facebook::velox::functions::sparksql::test::SparkFunctionBaseTest { + protected: + static void SetUpTestCase() { + facebook::velox::functions::sparksql::test::SparkFunctionBaseTest:: + SetUpTestCase(); + cudf_velox::registerCudf(); + } + + static void TearDownTestCase() { + facebook::velox::functions::sparksql::test::SparkFunctionBaseTest:: + TearDownTestCase(); + cudf_velox::unregisterCudf(); + } + + CudfFilterProjectTest() { + options_.parseIntegerAsBigint = false; + } +}; + +TEST_F(CudfFilterProjectTest, hashWithSeed) { + auto input = makeFlatVector({INT64_MAX, INT64_MIN}); + auto data = makeRowVector({input}); + auto hashPlan = PlanBuilder() + .setParseOptions(options_) + .values({data}) + .project({"hash_with_seed(42, c0) AS c1"}) + .planNode(); + auto hashResults = AssertQueryBuilder(hashPlan).copyResults(pool()); + + auto expected = makeRowVector({ + makeFlatVector({ + 1049813396, + 1800792340, + }), + }); + facebook::velox::test::assertEqualVectors(expected, 
hashResults); +} + +TEST_F(CudfFilterProjectTest, hashWithSeedMultiColumns) { + auto input = makeFlatVector({INT64_MAX, INT64_MIN}); + auto data = makeRowVector({input, input}); + auto hashPlan = PlanBuilder() + .setParseOptions(options_) + .values({data}) + .project({"hash_with_seed(42, c0, c1) AS c2"}) + .planNode(); + auto hashResults = AssertQueryBuilder(hashPlan).copyResults(pool()); + + auto expected = makeRowVector({ + makeFlatVector({ + -864217843, + 821064941, + }), + }); + facebook::velox::test::assertEqualVectors(expected, hashResults); +} +} // namespace diff --git a/velox/experimental/cudf/tests/utils/CMakeLists.txt b/velox/experimental/cudf/tests/utils/CMakeLists.txt index 3a782713aafb..eb1f8cb52d96 100644 --- a/velox/experimental/cudf/tests/utils/CMakeLists.txt +++ b/velox/experimental/cudf/tests/utils/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_cudf_exec_test_lib ParquetConnectorTestBase.cpp CudfPlanBuilder.cpp) +add_library(velox_cudf_exec_test_lib CudfHiveConnectorTestBase.cpp CudfPlanBuilder.cpp) set_target_properties(velox_cudf_exec_test_lib PROPERTIES CUDA_ARCHITECTURES native) @@ -30,6 +30,6 @@ target_link_libraries( velox_parse_parser velox_duckdb_conversion velox_file_test_utils - velox_cudf_parquet_connector + velox_cudf_hive_connector velox_aggregates ) diff --git a/velox/experimental/cudf/tests/utils/ParquetConnectorTestBase.cpp b/velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.cpp similarity index 71% rename from velox/experimental/cudf/tests/utils/ParquetConnectorTestBase.cpp rename to velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.cpp index 937df3345dc9..cdeaba197953 100644 --- a/velox/experimental/cudf/tests/utils/ParquetConnectorTestBase.cpp +++ b/velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.cpp @@ -14,16 +14,18 @@ * limitations under the License. 
*/ +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnector.h" +#include "velox/experimental/cudf/exec/ToCudf.h" #include "velox/experimental/cudf/exec/VeloxCudfInterop.h" -#include "velox/experimental/cudf/tests/utils/ParquetConnectorTestBase.h" -#include "velox/experimental/cudf/vector/CudfVector.h" +#include "velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.h" #include "velox/common/base/Exceptions.h" #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" -#include "velox/dwio/dwrf/writer/FlushPolicy.h" +#include "velox/exec/Driver.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include @@ -60,42 +62,52 @@ void fillColumnNames( } // namespace -ParquetConnectorTestBase::ParquetConnectorTestBase() { +using facebook::velox::connector::hive::HiveConnectorFactory; + +CudfHiveConnectorTestBase::CudfHiveConnectorTestBase() { filesystems::registerLocalFileSystem(); tests::utils::registerFaultyFileSystem(); } -void ParquetConnectorTestBase::SetUp() { +void CudfHiveConnectorTestBase::SetUp() { OperatorTestBase::SetUp(); - facebook::velox::connector::parquet::ParquetConnectorFactory factory; - auto parquetConnector = factory.newConnector( - kParquetConnectorId, + + // Register cudf to enable the CudfDatasource creation from CudfHiveConnector + facebook::velox::cudf_velox::registerCudf(); + + // Register Hive connector + facebook::velox::cudf_velox::connector::hive::CudfHiveConnectorFactory + factory; + auto hiveConnector = factory.newConnector( + kCudfHiveConnectorId, std::make_shared( std::unordered_map()), ioExecutor_.get()); - facebook::velox::connector::registerConnector(parquetConnector); + facebook::velox::connector::registerConnector(hiveConnector); dwio::common::registerFileSinks(); } -void ParquetConnectorTestBase::TearDown() { +void 
CudfHiveConnectorTestBase::TearDown() { // Make sure all pending loads are finished or cancelled before unregister // connector. ioExecutor_.reset(); - facebook::velox::connector::unregisterConnector(kParquetConnectorId); + facebook::velox::connector::unregisterConnector(kCudfHiveConnectorId); + facebook::velox::cudf_velox::unregisterCudf(); OperatorTestBase::TearDown(); } -void ParquetConnectorTestBase::resetParquetConnector( +void CudfHiveConnectorTestBase::resetCudfHiveConnector( const std::shared_ptr& config) { - facebook::velox::connector::unregisterConnector(kParquetConnectorId); + facebook::velox::connector::unregisterConnector(kCudfHiveConnectorId); - facebook::velox::connector::parquet::ParquetConnectorFactory factory; - auto parquetConnector = - factory.newConnector(kParquetConnectorId, config, ioExecutor_.get()); - facebook::velox::connector::registerConnector(parquetConnector); + facebook::velox::cudf_velox::connector::hive::CudfHiveConnectorFactory + factory; + auto hiveConnector = + factory.newConnector(kCudfHiveConnectorId, config, ioExecutor_.get()); + facebook::velox::connector::registerConnector(hiveConnector); } -std::vector ParquetConnectorTestBase::makeVectors( +std::vector CudfHiveConnectorTestBase::makeVectors( const RowTypePtr& rowType, int32_t numVectors, int32_t rowsPerVector) { @@ -109,17 +121,17 @@ std::vector ParquetConnectorTestBase::makeVectors( } std::shared_ptr -ParquetConnectorTestBase::assertQuery( +CudfHiveConnectorTestBase::assertQuery( const core::PlanNodePtr& plan, const std::vector< std::shared_ptr>& filePaths, const std::string& duckDbSql) { return OperatorTestBase::assertQuery( - plan, makeParquetConnectorSplits(filePaths), duckDbSql); + plan, makeCudfHiveConnectorSplits(filePaths), duckDbSql); } std::shared_ptr -ParquetConnectorTestBase::assertQuery( +CudfHiveConnectorTestBase::assertQuery( const facebook::velox::core::PlanNodePtr& plan, const std::vector< std::shared_ptr>& splits, @@ -135,7 +147,7 @@ 
ParquetConnectorTestBase::assertQuery( } std::vector> -ParquetConnectorTestBase::makeFilePaths(int count) { +CudfHiveConnectorTestBase::makeFilePaths(int count) { std::vector> filePaths; filePaths.reserve(count); @@ -145,7 +157,7 @@ ParquetConnectorTestBase::makeFilePaths(int count) { return filePaths; } -void ParquetConnectorTestBase::writeToFile( +void CudfHiveConnectorTestBase::writeToFile( const std::string& filePath, const std::vector& vectors, std::string prefix) { @@ -186,7 +198,7 @@ void ParquetConnectorTestBase::writeToFile( writer.close(); } -void ParquetConnectorTestBase::writeToFile( +void CudfHiveConnectorTestBase::writeToFile( const std::string& filePath, RowVectorPtr vector, std::string prefix) { @@ -204,40 +216,22 @@ void ParquetConnectorTestBase::writeToFile( cudf::io::write_parquet(options); } -std::unique_ptr -ParquetConnectorTestBase::makeColumnHandle( - const std::string& name, - const TypePtr& type, - const std::vector& children) { - return std::make_unique( - name, type, cudf::data_type(cudf::type_id::EMPTY), children); -} - -std::unique_ptr -ParquetConnectorTestBase::makeColumnHandle( - const std::string& name, - const TypePtr& type, - const cudf::data_type data_type, - const std::vector& children) { - return std::make_unique( - name, type, data_type, children); -} - std::vector> -ParquetConnectorTestBase::makeParquetConnectorSplits( +CudfHiveConnectorTestBase::makeCudfHiveConnectorSplits( const std::vector< std::shared_ptr>& filePaths) { std::vector> splits; for (const auto& filePath : filePaths) { - splits.push_back(makeParquetConnectorSplit(filePath->getPath())); + splits.push_back(makeCudfHiveConnectorSplit(filePath->getPath())); } return splits; } -std::vector> -ParquetConnectorTestBase::makeParquetConnectorSplits( +std::vector< + std::shared_ptr> +CudfHiveConnectorTestBase::makeCudfHiveConnectorSplits( const std::string& filePath, uint32_t splitCount) { auto file = @@ -245,46 +239,53 @@ 
ParquetConnectorTestBase::makeParquetConnectorSplits( const int64_t fileSize = file->size(); // Take the upper bound. const int64_t splitSize = std::ceil((fileSize) / splitCount); - std::vector> + std::vector< + std::shared_ptr> splits; // Add all the splits. for (int i = 0; i < splitCount; i++) { - auto split = ParquetConnectorSplitBuilder(filePath).build(); + auto split = + facebook::velox::connector::hive::HiveConnectorSplitBuilder(filePath) + .connectorId(kCudfHiveConnectorId) + .fileFormat(facebook::velox::dwio::common::FileFormat::PARQUET) + .build(); splits.push_back(std::move(split)); } return splits; } -std::shared_ptr -ParquetConnectorTestBase::makeParquetConnectorSplit( +std::shared_ptr +CudfHiveConnectorTestBase::makeCudfHiveConnectorSplit( const std::string& filePath, int64_t splitWeight) { - return ParquetConnectorSplitBuilder(filePath) + return facebook::velox::connector::hive::HiveConnectorSplitBuilder(filePath) + .connectorId(kCudfHiveConnectorId) + .fileFormat(facebook::velox::dwio::common::FileFormat::PARQUET) .splitWeight(splitWeight) .build(); } // static -std::shared_ptr -ParquetConnectorTestBase::makeParquetInsertTableHandle( +std::shared_ptr +CudfHiveConnectorTestBase::makeCudfHiveInsertTableHandle( const std::vector& tableColumnNames, const std::vector& tableColumnTypes, - std::shared_ptr locationHandle, + std::shared_ptr locationHandle, const std::optional compressionKind, const std::unordered_map& serdeParameters, const std::shared_ptr& writerOptions) { - std::vector> + std::vector> columnHandles; for (int i = 0; i < tableColumnNames.size(); ++i) { columnHandles.push_back( - std::make_shared( + std::make_shared( tableColumnNames.at(i), tableColumnTypes.at(i), cudf::data_type{veloxToCudfTypeId(tableColumnTypes.at(i))})); } - return std::make_shared( + return std::make_shared( columnHandles, locationHandle, compressionKind, diff --git a/velox/experimental/cudf/tests/utils/ParquetConnectorTestBase.h 
b/velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.h similarity index 62% rename from velox/experimental/cudf/tests/utils/ParquetConnectorTestBase.h rename to velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.h index 092af2a19b70..d67719b1a6e3 100644 --- a/velox/experimental/cudf/tests/utils/ParquetConnectorTestBase.h +++ b/velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.h @@ -16,12 +16,13 @@ #pragma once -#include "velox/experimental/cudf/connectors/parquet/ParquetConfig.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetConnector.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetDataSink.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetDataSource.h" -#include "velox/experimental/cudf/connectors/parquet/ParquetTableHandle.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConfig.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnector.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveDataSink.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveDataSource.h" +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/TableHandle.h" #include "velox/exec/Operator.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/TempFilePath.h" @@ -29,21 +30,21 @@ namespace facebook::velox::cudf_velox::exec::test { -static const std::string kParquetConnectorId = "test-parquet"; +static const std::string kCudfHiveConnectorId = "test-cudf-hive"; using ColumnHandleMap = std::unordered_map< std::string, std::shared_ptr>; -class ParquetConnectorTestBase +class CudfHiveConnectorTestBase : public facebook::velox::exec::test::OperatorTestBase { public: - ParquetConnectorTestBase(); + CudfHiveConnectorTestBase(); void SetUp() override; void TearDown() override; - void resetParquetConnector( + void resetCudfHiveConnector( const std::shared_ptr& config); void writeToFile( @@ -81,34 +82,34 @@ class 
ParquetConnectorTestBase static std::vector> makeFilePaths(int count); - static std::shared_ptr< - facebook::velox::cudf_velox::connector::parquet::ParquetConnectorSplit> - makeParquetConnectorSplit( + static std::shared_ptr + makeCudfHiveConnectorSplit( const std::string& filePath, int64_t splitWeight = 0); static std::vector< std::shared_ptr> - makeParquetConnectorSplits( + makeCudfHiveConnectorSplits( const std::vector< std::shared_ptr>& filePaths); - static std::vector> - makeParquetConnectorSplits(const std::string& filePath, uint32_t splitCount); + static std::vector< + std::shared_ptr> + makeCudfHiveConnectorSplits(const std::string& filePath, uint32_t splitCount); - static std::shared_ptr + static std::shared_ptr makeTableHandle( const std::string& tableName = "parquet_table", const RowTypePtr& dataColumns = nullptr, bool filterPushdownEnabled = false, - const core::TypedExprPtr& subfieldFilterExpr = nullptr, + common::SubfieldFilters subfieldFilters = {}, const core::TypedExprPtr& remainingFilterExpr = nullptr) { - return std::make_shared( - kParquetConnectorId, + return std::make_shared( + kCudfHiveConnectorId, tableName, filterPushdownEnabled, - subfieldFilterExpr, + std::move(subfieldFilters), remainingFilterExpr, dataColumns); } @@ -116,47 +117,56 @@ class ParquetConnectorTestBase /// @param name Column name. /// @param type Column type. /// @param Required subfields of this column. - static std::unique_ptr + static std::shared_ptr makeColumnHandle( const std::string& name, const TypePtr& type, - const std::vector& children); + facebook::velox::connector::hive::HiveColumnHandle::ColumnType + columnType = facebook::velox::connector::hive::HiveColumnHandle:: + ColumnType::kRegular, + const std::vector& requiredSubfields = + {}) { + return std::make_shared( + name, + columnType, + type, + type, + std::vector{}); + } /// @param name Column name. /// @param type Column type. /// @param type cudf column type. /// @param Required subfields of this column. 
- static std::unique_ptr + static std::unique_ptr makeColumnHandle( const std::string& name, const TypePtr& type, const cudf::data_type data_type, - const std::vector& children); + const std::vector& children); /// @param targetDirectory Final directory of the target table. /// @param tableType Whether to create a new table. - static std::shared_ptr makeLocationHandle( + static std::shared_ptr makeLocationHandle( std::string targetDirectory) { - return std::make_shared( - targetDirectory, - connector::parquet::LocationHandle::TableType::kNew, - ""); + return std::make_shared( + targetDirectory, connector::hive::LocationHandle::TableType::kNew, ""); } /// @param targetDirectory Final directory of the target table. /// @param tableType Whether to create a new table, insert into an existing /// table, or write a temporary table. /// @param targetDirectory Final file name of the target table . - static std::shared_ptr makeLocationHandle( + static std::shared_ptr makeLocationHandle( std::string targetDirectory, - connector::parquet::LocationHandle::TableType tableType = - connector::parquet::LocationHandle::TableType::kNew, + connector::hive::LocationHandle::TableType tableType = + connector::hive::LocationHandle::TableType::kNew, std::string targetFileName = "") { - return std::make_shared( + return std::make_shared( targetDirectory, tableType, targetFileName); } - /// Build a ParquetInsertTableHandle. + /// Build a CudfHiveInsertTableHandle. /// @param tableColumnNames Column names of the target table. Corresponding /// type of tableColumnNames[i] is tableColumnTypes[i]. /// @param tableColumnTypes Column types of the target table. Corresponding @@ -164,25 +174,25 @@ class ParquetConnectorTestBase /// @param locationHandle Location handle for the table write. /// @param compressionKind compression algorithm to use for table write. /// @param serdeParameters Table writer configuration parameters. 
- static std::shared_ptr - makeParquetInsertTableHandle( + static std::shared_ptr + makeCudfHiveInsertTableHandle( const std::vector& tableColumnNames, const std::vector& tableColumnTypes, - std::shared_ptr locationHandle, + std::shared_ptr locationHandle, const std::optional compressionKind = {}, const std::unordered_map& serdeParameters = {}, const std::shared_ptr& writerOptions = nullptr); }; -/// Same as connector::parquet::ParquetConnectorBuilder, except that this -/// defaults connectorId to kParquetConnectorId. -class ParquetConnectorSplitBuilder - : public connector::parquet::ParquetConnectorSplitBuilder { +/// Same as connector::hive::CudfHiveConnectorBuilder, except that this +/// defaults connectorId to kCudfHiveConnectorId. +class CudfHiveConnectorSplitBuilder + : public connector::hive::CudfHiveConnectorSplitBuilder { public: - explicit ParquetConnectorSplitBuilder(std::string filePath) - : connector::parquet::ParquetConnectorSplitBuilder(filePath) { - connectorId(kParquetConnectorId); + explicit CudfHiveConnectorSplitBuilder(std::string filePath) + : connector::hive::CudfHiveConnectorSplitBuilder(filePath) { + connectorId(kCudfHiveConnectorId); } }; diff --git a/velox/experimental/cudf/tests/utils/CudfPlanBuilder.cpp b/velox/experimental/cudf/tests/utils/CudfPlanBuilder.cpp index 7f444f343577..96856a43b4ed 100644 --- a/velox/experimental/cudf/tests/utils/CudfPlanBuilder.cpp +++ b/velox/experimental/cudf/tests/utils/CudfPlanBuilder.cpp @@ -25,7 +25,7 @@ namespace facebook::velox::cudf_velox::exec::test { std::function addCudfTableWriter( const RowTypePtr& inputColumns, const std::vector& tableColumnNames, - const std::shared_ptr& aggregationNode, + const std::optional& columnStatsSpec, const std::shared_ptr& insertHandle, facebook::velox::connector::CommitStrategy commitStrategy) { return [=](core::PlanNodeId nodeId, @@ -34,10 +34,10 @@ std::function addCudfTableWriter( nodeId, inputColumns, tableColumnNames, - aggregationNode, + columnStatsSpec, 
insertHandle, false, - TableWriteTraits::outputType(aggregationNode), + TableWriteTraits::outputType(columnStatsSpec), commitStrategy, std::move(source)); }; @@ -46,14 +46,14 @@ std::function addCudfTableWriter( std::function cudfTableWrite( const std::string& outputDirectoryPath, const dwio::common::FileFormat fileFormat, - const std::shared_ptr& aggregationNode, + const std::optional& columnStatsSpec, const std::shared_ptr& options, const std::string& outputFileName) { return cudfTableWrite( outputDirectoryPath, fileFormat, - aggregationNode, - kParquetConnectorId, + columnStatsSpec, + kCudfHiveConnectorId, {}, options, outputFileName); @@ -62,7 +62,7 @@ std::function cudfTableWrite( std::function cudfTableWrite( const std::string& outputDirectoryPath, const dwio::common::FileFormat fileFormat, - const std::shared_ptr& aggregationNode, + const std::optional& columnStatsSpec, const std::string_view& connectorId, const std::unordered_map& serdeParameters, const std::shared_ptr& options, @@ -73,12 +73,13 @@ std::function cudfTableWrite( core::PlanNodePtr source) -> core::PlanNodePtr { auto rowType = schema ? 
schema : source->outputType(); - auto locationHandle = ParquetConnectorTestBase::makeLocationHandle( + auto locationHandle = CudfHiveConnectorTestBase::makeLocationHandle( outputDirectoryPath, - cudf_velox::connector::parquet::LocationHandle::TableType::kNew, + cudf_velox::connector::hive::LocationHandle::TableType::kNew, outputFileName); - auto parquetHandle = ParquetConnectorTestBase::makeParquetInsertTableHandle( - rowType->names(), rowType->children(), locationHandle, compression); + auto parquetHandle = + CudfHiveConnectorTestBase::makeCudfHiveInsertTableHandle( + rowType->names(), rowType->children(), locationHandle, compression); auto insertHandle = std::make_shared( std::string(connectorId), parquetHandle); @@ -86,10 +87,10 @@ std::function cudfTableWrite( nodeId, rowType, rowType->names(), - aggregationNode, + columnStatsSpec, insertHandle, false, - TableWriteTraits::outputType(aggregationNode), + TableWriteTraits::outputType(columnStatsSpec), facebook::velox::connector::CommitStrategy::kNoCommit, std::move(source)); }; diff --git a/velox/experimental/cudf/tests/utils/CudfPlanBuilder.h b/velox/experimental/cudf/tests/utils/CudfPlanBuilder.h index 9626f60b57ae..f2cf39e5ede7 100644 --- a/velox/experimental/cudf/tests/utils/CudfPlanBuilder.h +++ b/velox/experimental/cudf/tests/utils/CudfPlanBuilder.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#include "velox/experimental/cudf/connectors/parquet/ParquetDataSink.h" -#include "velox/experimental/cudf/tests/utils/ParquetConnectorTestBase.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveDataSink.h" +#include "velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.h" #include "velox/dwio/common/Options.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -33,21 +33,21 @@ using namespace facebook::velox::common::test; using namespace facebook::velox::common::testutil; using namespace facebook::velox::dwio::common; -// Adds a TableWriter node to write all input columns into a Parquet table. +// Adds a TableWriter node to write all input columns into a CudfHive table. std::function addCudfTableWriter( const RowTypePtr& inputColumns, const std::vector& tableColumnNames, - const std::shared_ptr& aggregationNode, + const std::optional& columnStatsSpec, const std::shared_ptr& insertHandle, facebook::velox::connector::CommitStrategy commitStrategy = facebook::velox::connector::CommitStrategy::kNoCommit); /// Adds a TableWriteNode to write all input columns into an un-partitioned -/// un-bucketed Parquet table without compression. +/// un-bucketed CudfHive table without compression. /// /// @param outputDirectoryPath Path to a directory to write data to. /// @param fileFormat File format to use for the written data. -/// @param aggregationNode AggregationNode for column statistics collection +/// @param columnStatsSpec ColumnStatsSpec for column statistics collection /// during write. /// @param polymorphic options object to be passed to the writer. /// write, supported aggregation types vary for different column types. 
@@ -63,16 +63,16 @@ std::function cudfTableWrite( const std::string& outputDirectoryPath, const dwio::common::FileFormat fileFormat = dwio::common::FileFormat::PARQUET, - const std::shared_ptr& aggregationNode = nullptr, + const std::optional& columnStatsSpec = std::nullopt, const std::shared_ptr& options = nullptr, const std::string& outputFileName = ""); -/// Adds a TableWriteNode to write all input columns into Parquet +/// Adds a TableWriteNode to write all input columns into CudfHive /// table with compression. /// /// @param outputDirectoryPath Path to a directory to write data to. /// @param fileFormat File format to use for the written data. -/// @param aggregationNode AggregationNode for column statistics collection +/// @param columnStatsSpec ColumnStatsSpec for column statistics collection /// during write. /// @param connectorId Name used to register the connector. /// @param serdeParameters Additional parameters passed to the writer. @@ -87,8 +87,8 @@ std::function cudfTableWrite( std::function cudfTableWrite( const std::string& outputDirectoryPath, const dwio::common::FileFormat fileFormat, - const std::shared_ptr& aggregationNode, - const std::string_view& connectorId = kParquetConnectorId, + const std::optional& columnStatsSpec, + const std::string_view& connectorId = kCudfHiveConnectorId, const std::unordered_map& serdeParameters = {}, const std::shared_ptr& options = nullptr, const std::string& outputFileName = "", diff --git a/velox/experimental/wave/dwio/nimble/tests/NimbleReaderTest.cpp b/velox/experimental/wave/dwio/nimble/tests/NimbleReaderTest.cpp index dc80719944fe..bb1119110653 100644 --- a/velox/experimental/wave/dwio/nimble/tests/NimbleReaderTest.cpp +++ b/velox/experimental/wave/dwio/nimble/tests/NimbleReaderTest.cpp @@ -312,7 +312,7 @@ TEST_F(NimbleReaderTest, decodeTrivialSingleLevelFloat) { test({{input}}, readFactors, compressionOptions); } -TEST_F(NimbleReaderTest, TrivialWithCompressionShouldFail) { +TEST_F(NimbleReaderTest, 
DISABLED_TrivialWithCompressionShouldFail) { using namespace facebook::nimble; auto c0 = makeFlatVector(17, [](auto row) { return row * 1.1; }); diff --git a/velox/experimental/wave/exec/TableScan.cpp b/velox/experimental/wave/exec/TableScan.cpp index 6f09865a76a1..6a1cb58389b3 100644 --- a/velox/experimental/wave/exec/TableScan.cpp +++ b/velox/experimental/wave/exec/TableScan.cpp @@ -90,6 +90,29 @@ void TableScan::updateStats( } } +void TableScan::updateStats( + std::unordered_map connectorStats, + WaveSplitReader* splitReader) { + auto lockedStats = stats().wlock(); + if (splitReader) { + lockedStats->rawInputPositions = splitReader->getCompletedRows(); + lockedStats->rawInputBytes = splitReader->getCompletedBytes(); + } + for (const auto& [name, metric] : connectorStats) { + if (name == "ioWaitNanos") { + ioWaitNanos_ += metric.sum - lastIoWaitNanos_; + lastIoWaitNanos_ = metric.sum; + } + if (UNLIKELY(lockedStats->runtimeStats.count(name) == 0)) { + lockedStats->runtimeStats.insert( + std::make_pair(name, RuntimeMetric(metric.unit))); + } else { + VELOX_CHECK_EQ(lockedStats->runtimeStats.at(name).unit, metric.unit); + } + lockedStats->runtimeStats.at(name).merge(metric); + } +} + BlockingReason TableScan::nextSplit(ContinueFuture* future) { exec::Split split; blockingReason_ = driverCtx_->task->getSplitOrFuture( @@ -106,7 +129,7 @@ BlockingReason TableScan::nextSplit(ContinueFuture* future) { if (!split.hasConnectorSplit()) { noMoreSplits_ = true; if (dataSource_) { - updateStats(dataSource_->runtimeStats()); + updateStats(dataSource_->getRuntimeStats()); } return BlockingReason::kNotBlocked; } diff --git a/velox/experimental/wave/exec/TableScan.h b/velox/experimental/wave/exec/TableScan.h index e2481573223e..905cc0f7b80c 100644 --- a/velox/experimental/wave/exec/TableScan.h +++ b/velox/experimental/wave/exec/TableScan.h @@ -96,6 +96,9 @@ class TableScan : public WaveSourceOperator { void updateStats( std::unordered_map stats, WaveSplitReader* splitReader = 
nullptr); + void updateStats( + std::unordered_map stats, + WaveSplitReader* splitReader = nullptr); // Process-wide IO wait time. static std::atomic ioWaitNanos_; diff --git a/velox/experimental/wave/exec/WaveDataSource.h b/velox/experimental/wave/exec/WaveDataSource.h index caf7c28f2504..be0cf309cd95 100644 --- a/velox/experimental/wave/exec/WaveDataSource.h +++ b/velox/experimental/wave/exec/WaveDataSource.h @@ -56,7 +56,7 @@ class WaveDataSource : public std::enable_shared_from_this { virtual uint64_t getCompletedRows() = 0; - virtual std::unordered_map runtimeStats() = 0; + virtual std::unordered_map getRuntimeStats() = 0; virtual void setFromDataSource(std::shared_ptr source) { VELOX_UNSUPPORTED(); diff --git a/velox/experimental/wave/exec/WaveHiveDataSource.cpp b/velox/experimental/wave/exec/WaveHiveDataSource.cpp index 5bbbfcdc1b80..8dea8094458b 100644 --- a/velox/experimental/wave/exec/WaveHiveDataSource.cpp +++ b/velox/experimental/wave/exec/WaveHiveDataSource.cpp @@ -140,11 +140,12 @@ uint64_t WaveHiveDataSource::getCompletedRows() { return completedRows_; } -std::unordered_map -WaveHiveDataSource::runtimeStats() { - auto map = runtimeStats_.toMap(); +std::unordered_map +WaveHiveDataSource::getRuntimeStats() { + auto map = runtimeStats_.toRuntimeMetricMap(); for (const auto& [name, counter] : splitReaderStats_) { - map.insert(std::make_pair(name, counter)); + map.insert( + std::make_pair(name, RuntimeMetric(counter.value, counter.unit))); } return map; } diff --git a/velox/experimental/wave/exec/WaveHiveDataSource.h b/velox/experimental/wave/exec/WaveHiveDataSource.h index 9f994d459cce..01634f701d1a 100644 --- a/velox/experimental/wave/exec/WaveHiveDataSource.h +++ b/velox/experimental/wave/exec/WaveHiveDataSource.h @@ -59,7 +59,7 @@ class WaveHiveDataSource : public WaveDataSource { uint64_t getCompletedRows() override; - std::unordered_map runtimeStats() override; + std::unordered_map getRuntimeStats() override; static void registerConnector(); diff 
--git a/velox/expression/CMakeLists.txt b/velox/expression/CMakeLists.txt index bf2b0d338805..626659030e48 100644 --- a/velox/expression/CMakeLists.txt +++ b/velox/expression/CMakeLists.txt @@ -42,6 +42,7 @@ velox_add_library( EvalCtx.cpp Expr.cpp ExprCompiler.cpp + ExprRewriteRegistry.cpp ExprToSubfieldFilter.cpp ExprUtils.cpp FieldReference.cpp diff --git a/velox/expression/CastExpr.cpp b/velox/expression/CastExpr.cpp index 99d3b3005896..2819481d27b4 100644 --- a/velox/expression/CastExpr.cpp +++ b/velox/expression/CastExpr.cpp @@ -24,6 +24,7 @@ #include "velox/expression/PeeledEncoding.h" #include "velox/expression/PrestoCastHooks.h" #include "velox/expression/ScopedVarSetter.h" +#include "velox/external/tzdb/time_zone.h" #include "velox/functions/lib/RowsTranslationUtil.h" #include "velox/type/Type.h" #include "velox/type/tz/TimeZoneMap.h" @@ -213,6 +214,174 @@ VectorPtr CastExpr::castFromIntervalDayTime( } } +VectorPtr CastExpr::castFromTime( + const SelectivityVector& rows, + const BaseVector& input, + exec::EvalCtx& context, + const TypePtr& toType) { + VectorPtr castResult; + context.ensureWritable(rows, toType, castResult); + (*castResult).clearNulls(rows); + + auto* inputFlatVector = input.as>(); + switch (toType->kind()) { + case TypeKind::VARCHAR: { + // Get session timezone + const auto* timeZone = + getTimeZoneFromConfig(context.execCtx()->queryCtx()->queryConfig()); + // Get session start time + const auto startTimeMs = + context.execCtx()->queryCtx()->queryConfig().sessionStartTimeMs(); + auto systemDay = std::chrono::milliseconds{startTimeMs} / kMillisInDay; + + auto* resultFlatVector = castResult->as>(); + + Buffer* buffer = resultFlatVector->getBufferWithSpace( + rows.countSelected() * TimeType::kTimeToVarcharRowSize, + true /*exactSize*/); + char* rawBuffer = buffer->asMutable() + buffer->size(); + + applyToSelectedNoThrowLocal(context, rows, castResult, [&](int row) { + try { + // Use timezone-aware conversion + auto systemTime = + 
systemDay.count() * kMillisInDay + inputFlatVector->valueAt(row); + + int64_t adjustedTime{0}; + if (timeZone) { + adjustedTime = + (timeZone->to_local(std::chrono::milliseconds{systemTime}) % + kMillisInDay) + .count(); + } else { + adjustedTime = systemTime % kMillisInDay; + } + + if (adjustedTime < 0) { + adjustedTime += kMillisInDay; + } + + auto output = TIME()->valueToString(adjustedTime, rawBuffer); + resultFlatVector->setNoCopy(row, output); + rawBuffer += output.size(); + } catch (const VeloxException& ue) { + if (!ue.isUserError()) { + throw; + } + VELOX_USER_FAIL( + makeErrorMessage(input, row, toType) + " " + ue.message()); + } catch (const std::exception& e) { + VELOX_USER_FAIL( + makeErrorMessage(input, row, toType) + " " + e.what()); + } + }); + + buffer->setSize(rawBuffer - buffer->asMutable()); + return castResult; + } + case TypeKind::BIGINT: { + // if input is constant, create a constant output vector + if (input.isConstantEncoding()) { + auto constantInput = input.as>(); + if (constantInput->isNullAt(0)) { + return BaseVector::createNullConstant( + toType, rows.end(), context.pool()); + } else { + auto constantValue = constantInput->valueAt(0); + return std::make_shared>( + context.pool(), + rows.end(), + false, // isNull + toType, + std::move(constantValue)); + } + } + + // fallback to element-wise copy for non-constant inputs + auto* resultFlatVector = castResult->as>(); + applyToSelectedNoThrowLocal(context, rows, castResult, [&](int row) { + resultFlatVector->set(row, inputFlatVector->valueAt(row)); + }); + return castResult; + } + case TypeKind::TIMESTAMP: { + // if input is constant, create a constant output vector + if (input.isConstantEncoding()) { + auto constantInput = input.as>(); + if (constantInput->isNullAt(0)) { + return BaseVector::createNullConstant( + toType, rows.end(), context.pool()); + } else { + auto timeMillis = constantInput->valueAt(0); + return std::make_shared>( + context.pool(), + rows.end(), + false, // isNull + 
toType, + Timestamp::fromMillis(timeMillis)); + } + } + + // fallback to element-wise copy for non-constant inputs + auto* resultFlatVector = castResult->as>(); + applyToSelectedNoThrowLocal(context, rows, castResult, [&](int row) { + auto timeMillis = inputFlatVector->valueAt(row); + resultFlatVector->set(row, Timestamp::fromMillis(timeMillis)); + }); + return castResult; + } + default: + VELOX_UNSUPPORTED( + "Cast from TIME to {} is not supported", toType->toString()); + } +} + +VectorPtr CastExpr::castToTime( + const SelectivityVector& rows, + const BaseVector& input, + exec::EvalCtx& context, + const TypePtr& fromType) { + switch (fromType->kind()) { + case TypeKind::VARCHAR: { + VectorPtr castResult; + context.ensureWritable(rows, TIME(), castResult); + (*castResult).clearNulls(rows); + + // Get session timezone and start time for timezone conversions + const auto* timeZone = + getTimeZoneFromConfig(context.execCtx()->queryCtx()->queryConfig()); + const auto sessionStartTimeMs = + context.execCtx()->queryCtx()->queryConfig().sessionStartTimeMs(); + + auto* inputVector = input.as>(); + auto* resultFlatVector = castResult->as>(); + + applyToSelectedNoThrowLocal(context, rows, castResult, [&](int row) { + try { + const auto inputString = inputVector->valueAt(row); + int64_t result = + TIME()->valueToTime(inputString, timeZone, sessionStartTimeMs); + resultFlatVector->set(row, result); + } catch (const VeloxException& ue) { + if (!ue.isUserError()) { + throw; + } + VELOX_USER_FAIL( + makeErrorMessage(input, row, TIME()) + " " + ue.message()); + } catch (const std::exception& e) { + VELOX_USER_FAIL( + makeErrorMessage(input, row, TIME()) + " " + e.what()); + } + }); + + return castResult; + } + default: + VELOX_UNSUPPORTED( + "Cast from {} to TIME is not supported", fromType->toString()); + } +} + namespace { void propagateErrorsOrSetNulls( bool setNullInResultAtError, @@ -220,11 +389,11 @@ void propagateErrorsOrSetNulls( const SelectivityVector& nestedRows, const 
BufferPtr& elementToTopLevelRows, VectorPtr& result, - EvalErrorsPtr& oldErrors) { + exec::EvalErrorsPtr& oldErrors) { if (context.errors()) { if (setNullInResultAtError) { - // Errors in context.errors() should be translated to nulls in the top - // level rows. + // Errors in context.errors() should be translated to nulls in + // the top level rows. context.convertElementErrorsToTopLevelNulls( nestedRows, elementToTopLevelRows, result); } else { @@ -251,8 +420,8 @@ VectorPtr CastExpr::applyMap( exec::EvalCtx& context, const MapType& fromType, const MapType& toType) { - // Cast input keys/values vector to output keys/values vector using their - // element selectivity vector + // Cast input keys/values vector to output keys/values vector using + // their element selectivity vector // Initialize nested rows auto mapKeys = input->mapKeys(); @@ -304,8 +473,8 @@ VectorPtr CastExpr::applyMap( } } - // Returned map vector should be addressable for every element, even those - // that are not selected. + // Returned map vector should be addressable for every element, even + // those that are not selected. BufferPtr sizes = input->sizes(); if (newMapKeys->isConstantEncoding() && newMapValues->isConstantEncoding()) { // We extends size since that is cheap. @@ -353,8 +522,8 @@ VectorPtr CastExpr::applyArray( exec::EvalCtx& context, const ArrayType& fromType, const ArrayType& toType) { - // Cast input array elements to output array elements based on their types - // using their linear selectivity vector + // Cast input array elements to output array elements based on their + // types using their linear selectivity vector auto arrayElements = input->elements(); auto nestedRows = @@ -377,8 +546,8 @@ VectorPtr CastExpr::applyArray( newElements); } - // Returned array vector should be addressable for every element, even those - // that are not selected. + // Returned array vector should be addressable for every element, + // even those that are not selected. 
BufferPtr sizes = input->sizes(); if (newElements->isConstantEncoding()) { // If the newElements we extends its size since that is cheap. @@ -423,8 +592,8 @@ VectorPtr CastExpr::applyRow( int numInputChildren = input->children().size(); int numOutputChildren = toRowType.size(); - // Extract the flag indicating matching of children must be done by name or - // position + // Extract the flag indicating matching of children must be done by + // name or position auto matchByName = context.execCtx()->queryCtx()->queryConfig().isMatchStructByName(); @@ -434,14 +603,16 @@ VectorPtr CastExpr::applyRow( EvalErrorsPtr oldErrors; if (setNullInResultAtError()) { - // We need to isolate errors that happen during the cast from previous - // errors since those translate to nulls, unlike exisiting errors. + // We need to isolate errors that happen during the cast from + // previous errors since those translate to nulls, unlike + // exisiting errors. context.swapErrors(oldErrors); } for (auto toChildrenIndex = 0; toChildrenIndex < numOutputChildren; toChildrenIndex++) { - // For each child, find the corresponding column index in the output + // For each child, find the corresponding column index in the + // output const auto& toFieldName = toRowType.nameOf(toChildrenIndex); bool matchNotFound = false; @@ -615,8 +786,8 @@ void CastExpr::applyPeeled( }; if (setNullInResultAtError()) { - // This can be optimized by passing setNullInResultAtError() to castTo and - // castFrom operations. + // This can be optimized by passing setNullInResultAtError() to + // castTo and castFrom operations. 
EvalErrorsPtr oldErrors; context.swapErrors(oldErrors); @@ -650,6 +821,10 @@ void CastExpr::applyPeeled( "Cast from {} to {} is not supported", fromType->toString(), toType->toString()); + } else if (fromType->isTime()) { + result = castFromTime(rows, input, context, toType); + } else if (toType->isTime()) { + result = castToTime(rows, input, context, fromType); } else if (toType->isShortDecimal()) { result = applyDecimal(rows, input, context, fromType, toType); } else if (toType->isLongDecimal()) { @@ -769,8 +944,8 @@ VectorPtr CastExpr::applyTimestampToVarcharCast( const auto stringView = Timestamp::tsToStringView(inputValue, options, rawBuffer); flatResult->setNoCopy(row, stringView); - // The result of both Presto and Spark contains more than 12 digits even - // when 'zeroPaddingYear' is disabled. + // The result of both Presto and Spark contains more than 12 + // digits even when 'zeroPaddingYear' is disabled. VELOX_DCHECK(!stringView.isInline()); rawBuffer += stringView.size(); }); @@ -867,8 +1042,8 @@ void CastExpr::apply( context.moveOrCopyResult(localResult, *remainingRows, result); context.releaseVector(localResult); - // If there are nulls or rows that encountered errors in the input, add nulls - // to the result at the same rows. + // If there are nulls or rows that encountered errors in the input, + // add nulls to the result at the same rows. VELOX_CHECK_NOT_NULL(result); if (rawNulls || context.errors()) { EvalCtx::addNulls( @@ -897,7 +1072,8 @@ void CastExpr::evalSpecialForm( } else { apply(rows, input, context, fromType, toType, result); } - // Return 'input' back to the vector pool in 'context' so it can be reused. + // Return 'input' back to the vector pool in 'context' so it can be + // reused. 
context.releaseVector(input); } diff --git a/velox/expression/CastExpr.h b/velox/expression/CastExpr.h index 5ae6d9520188..09c1c03baff5 100644 --- a/velox/expression/CastExpr.h +++ b/velox/expression/CastExpr.h @@ -204,6 +204,18 @@ class CastExpr : public SpecialForm { exec::EvalCtx& context, const TypePtr& toType); + VectorPtr castFromTime( + const SelectivityVector& rows, + const BaseVector& input, + exec::EvalCtx& context, + const TypePtr& toType); + + VectorPtr castToTime( + const SelectivityVector& rows, + const BaseVector& input, + exec::EvalCtx& context, + const TypePtr& fromType); + template void applyDecimalCastKernel( const SelectivityVector& rows, diff --git a/velox/expression/EvalCtx.cpp b/velox/expression/EvalCtx.cpp index 8a2ac21fae53..ac6cc68aea42 100644 --- a/velox/expression/EvalCtx.cpp +++ b/velox/expression/EvalCtx.cpp @@ -192,6 +192,25 @@ void EvalCtx::setStatus(vector_size_t index, const Status& status) { } } +void EvalCtx::setStatuses(const SelectivityVector& rows, const Status& status) { + VELOX_CHECK(!status.ok(), "Status must be an error"); + if (status.isUserError()) { + if (throwOnError_) { + VELOX_USER_FAIL(status.message()); + } + + if (captureErrorDetails_) { + auto veloxException = toVeloxUserError(status.message()); + rows.applyToSelected( + [&](auto row) { addError(row, veloxException, errors_); }); + } else { + rows.applyToSelected([&](auto row) { addError(row, errors_); }); + } + } else { + VELOX_FAIL(status.message()); + } +} + void EvalCtx::setError( vector_size_t index, const std::exception_ptr& exceptionPtr) { diff --git a/velox/expression/EvalCtx.h b/velox/expression/EvalCtx.h index cab284626d15..e544a2e84bcf 100644 --- a/velox/expression/EvalCtx.h +++ b/velox/expression/EvalCtx.h @@ -266,6 +266,7 @@ class EvalCtx { // @param status Must indicate an error. Cannot be "ok". 
void setStatus(vector_size_t index, const Status& status); + void setStatuses(const SelectivityVector& rows, const Status& status); // If exceptionPtr is known to be a VeloxException use setVeloxExceptionError // instead. @@ -782,40 +783,40 @@ class LocalSelectivityVector { class LocalDecodedVector { public: - explicit LocalDecodedVector(core::ExecCtx& context) : context_(context) {} + explicit LocalDecodedVector(core::ExecCtx& context) : context_(&context) {} - explicit LocalDecodedVector(EvalCtx& context) - : context_(*context.execCtx()) {} + explicit LocalDecodedVector(EvalCtx& evalCtx) : context_(evalCtx.execCtx()) {} - explicit LocalDecodedVector(EvalCtx* context) - : LocalDecodedVector(*context) {} + explicit LocalDecodedVector(EvalCtx* evalCtx) + : context_(evalCtx ? evalCtx->execCtx() : nullptr) {} LocalDecodedVector( const EvalCtx& context, const BaseVector& vector, const SelectivityVector& rows, bool loadLazy = true) - : context_(*context.execCtx()) { + : context_(context.execCtx()) { get()->decode(vector, rows, loadLazy); } LocalDecodedVector(LocalDecodedVector&& other) noexcept : context_{other.context_}, vector_{std::move(other.vector_)} {} - void operator=(LocalDecodedVector&& other) { + void operator=(LocalDecodedVector&& other) noexcept { context_ = other.context_; vector_ = std::move(other.vector_); } ~LocalDecodedVector() { - if (vector_) { - context_.get().releaseDecodedVector(std::move(vector_)); + if (vector_ && context_) { + context_->releaseDecodedVector(std::move(vector_)); } } DecodedVector* get() { if (!vector_) { - vector_ = context_.get().getDecodedVector(); + vector_ = context_ ? 
context_->getDecodedVector() + : std::make_unique(); } return vector_.get(); } @@ -842,7 +843,7 @@ class LocalDecodedVector { } private: - std::reference_wrapper context_; + core::ExecCtx* context_; std::unique_ptr vector_; }; diff --git a/velox/expression/ExprCompiler.cpp b/velox/expression/ExprCompiler.cpp index b214261933ef..f75b755d6f0a 100644 --- a/velox/expression/ExprCompiler.cpp +++ b/velox/expression/ExprCompiler.cpp @@ -18,13 +18,13 @@ #include "velox/expression/ConstantExpr.h" #include "velox/expression/Expr.h" #include "velox/expression/ExprConstants.h" +#include "velox/expression/ExprRewriteRegistry.h" #include "velox/expression/ExprUtils.h" #include "velox/expression/FieldReference.h" #include "velox/expression/LambdaExpr.h" #include "velox/expression/RowConstructor.h" #include "velox/expression/SimpleFunctionRegistry.h" #include "velox/expression/SpecialFormRegistry.h" -#include "velox/expression/VectorFunction.h" namespace facebook::velox::exec { @@ -105,9 +105,21 @@ std::optional shouldFlatten( return std::nullopt; } -ExprPtr getAlreadyCompiled(const ITypedExpr* expr, ExprDedupMap* visited) { +ExprPtr getAlreadyCompiled( + const ITypedExpr* expr, + const core::QueryConfig& config, + ExprDedupMap* visited) { auto iter = visited->find(expr); - return iter == visited->end() ? nullptr : iter->second; + if (iter == visited->end()) { + return nullptr; + } + + const ExprPtr& alreadyCompiled = iter->second; + if (alreadyCompiled->isDeterministic()) { + return alreadyCompiled; + } + + return config.exprDedupNonDeterministic() ? 
alreadyCompiled : nullptr; } ExprPtr compileExpression( @@ -226,7 +238,7 @@ std::shared_ptr compileLambda( captureReferences.reserve(lambdaScope.capture.size()); for (auto i = 0; i < lambdaScope.capture.size(); ++i) { auto expr = lambdaScope.captureFieldAccesses[i]; - auto reference = getAlreadyCompiled(expr, &scope->visited); + auto reference = getAlreadyCompiled(expr, config, &scope->visited); if (!reference) { auto inner = lambdaScope.captureReferences[i]; reference = std::make_shared( @@ -305,15 +317,6 @@ std::vector getConstantInputs(const std::vector& exprs) { return constants; } -core::TypedExprPtr rewriteExpression(const core::TypedExprPtr& expr) { - for (auto& rewrite : expressionRewrites()) { - if (auto rewritten = rewrite(expr)) { - return rewritten; - } - } - return expr; -} - ExprPtr compileCall( const TypedExprPtr& expr, std::vector inputs, @@ -422,7 +425,8 @@ ExprPtr compileRewrittenExpression( memory::MemoryPool* pool, const std::unordered_set& flatteningCandidates, bool enableConstantFolding) { - ExprPtr alreadyCompiled = getAlreadyCompiled(expr.get(), &scope->visited); + ExprPtr alreadyCompiled = + getAlreadyCompiled(expr.get(), config, &scope->visited); if (alreadyCompiled) { if (!alreadyCompiled->isMultiplyReferenced()) { scope->exprSet->addToReset(alreadyCompiled); @@ -446,7 +450,7 @@ ExprPtr compileRewrittenExpression( case core::ExprKind::kConcat: { result = getSpecialForm( config, - RowConstructorCallToSpecialForm::kRowConstructor, + expression::kRowConstructor, resultType, std::move(compiledInputs), trackCpuUsage); @@ -527,7 +531,7 @@ ExprPtr compileExpression( memory::MemoryPool* pool, const std::unordered_set& flatteningCandidates, bool enableConstantFolding) { - auto rewritten = rewriteExpression(expr); + auto rewritten = expression::ExprRewriteRegistry::instance().rewrite(expr); if (rewritten.get() != expr.get()) { scope->rewrittenExpressions.push_back(rewritten); } diff --git a/velox/expression/ExprConstants.h 
b/velox/expression/ExprConstants.h index 7cdf322d93d3..02f0b46156aa 100644 --- a/velox/expression/ExprConstants.h +++ b/velox/expression/ExprConstants.h @@ -17,14 +17,14 @@ namespace facebook::velox::expression { -constexpr const char* kAnd = "and"; -constexpr const char* kOr = "or"; -constexpr const char* kSwitch = "switch"; -constexpr const char* kIf = "if"; -constexpr const char* kFail = "fail"; -constexpr const char* kCoalesce = "coalesce"; -constexpr const char* kCast = "cast"; -constexpr const char* kTryCast = "try_cast"; -constexpr const char* kTry = "try"; +inline constexpr const char* kAnd = "and"; +inline constexpr const char* kOr = "or"; +inline constexpr const char* kSwitch = "switch"; +inline constexpr const char* kIf = "if"; +inline constexpr const char* kCoalesce = "coalesce"; +inline constexpr const char* kCast = "cast"; +inline constexpr const char* kTryCast = "try_cast"; +inline constexpr const char* kTry = "try"; +inline constexpr const char* kRowConstructor = "row_constructor"; } // namespace facebook::velox::expression diff --git a/velox/expression/ExprRewriteRegistry.cpp b/velox/expression/ExprRewriteRegistry.cpp new file mode 100644 index 000000000000..5a387f26a4c5 --- /dev/null +++ b/velox/expression/ExprRewriteRegistry.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "velox/expression/ExprRewriteRegistry.h" +#include "velox/expression/FunctionSignature.h" + +namespace facebook::velox::expression { + +void ExprRewriteRegistry::registerRewrite(ExpressionRewrite rewrite) { + registry_.withWLock([&](auto& list) { list.push_back(std::move(rewrite)); }); +} + +void ExprRewriteRegistry::clear() { + registry_.withWLock([&](auto& list) { list.clear(); }); +} + +core::TypedExprPtr ExprRewriteRegistry::rewrite( + const core::TypedExprPtr& expr) { + core::TypedExprPtr result = expr; + registry_.withRLock([&](const auto& list) { + for (const auto& rewrite : list) { + VELOX_CHECK_NOT_NULL(rewrite); + if (auto rewritten = (rewrite)(expr)) { + result = rewritten; + break; + } + } + }); + + return result; +} +} // namespace facebook::velox::expression diff --git a/velox/expression/ExprRewriteRegistry.h b/velox/expression/ExprRewriteRegistry.h new file mode 100644 index 000000000000..9b49403bb982 --- /dev/null +++ b/velox/expression/ExprRewriteRegistry.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/core/Expressions.h" + +namespace facebook::velox::expression { + +/// An expression re-writer that takes an expression and returns an equivalent +/// expression or nullptr if re-write is not possible. 
+using ExpressionRewrite = + std::function; + +class ExprRewriteRegistry { + public: + /// Appends a 'rewrite' to 'expressionRewrites'. + /// + /// The logic that applies re-writes is very simple and assumes that all + /// rewrites are independent. For each expression, rewrites are applied in the + /// order they were registered. The first rewrite that returns non-null result + /// terminates the re-write for that particular expression. + void registerRewrite(ExpressionRewrite rewrite); + + /// Clears the registry to remove all registered rewrites. + void clear(); + + core::TypedExprPtr rewrite(const core::TypedExprPtr& expr); + + static ExprRewriteRegistry& instance() { + static ExprRewriteRegistry kInstance; + return kInstance; + } + + private: + folly::Synchronized> registry_; +}; +} // namespace facebook::velox::expression diff --git a/velox/expression/ExprStats.h b/velox/expression/ExprStats.h index 2d285b33ac14..532925cbff25 100644 --- a/velox/expression/ExprStats.h +++ b/velox/expression/ExprStats.h @@ -35,6 +35,8 @@ struct ExprStats { /// evaluation of rows. 
bool defaultNullRowsSkipped{false}; + auto operator<=>(const ExprStats&) const = default; + void add(const ExprStats& other) { timing.add(other.timing); numProcessedRows += other.numProcessedRows; diff --git a/velox/expression/FunctionSignature.cpp b/velox/expression/FunctionSignature.cpp index ff223cdabbce..8da3b784b097 100644 --- a/velox/expression/FunctionSignature.cpp +++ b/velox/expression/FunctionSignature.cpp @@ -106,6 +106,12 @@ void validateBaseTypeAndCollectTypeParams( const TypeSignature& arg, std::unordered_set<std::string>& collectedTypeVariables, bool isReturnType) { + if (isReturnType) { + VELOX_USER_CHECK( + !arg.isHomogeneousRow(), + "Homogeneous row cannot appear in return type"); + } + if (!variables.count(arg.baseName())) { auto typeName = boost::algorithm::to_upper_copy(arg.baseName()); diff --git a/velox/expression/RegisterSpecialForm.cpp b/velox/expression/RegisterSpecialForm.cpp index bc20675c63dc..f9a64ee3b373 100644 --- a/velox/expression/RegisterSpecialForm.cpp +++ b/velox/expression/RegisterSpecialForm.cpp @@ -28,6 +28,7 @@ #include "velox/expression/TryExpr.h" namespace facebook::velox::exec { + void registerFunctionCallToSpecialForms() { registerFunctionCallToSpecialForm( expression::kAnd, @@ -48,7 +49,8 @@ void registerFunctionCallToSpecialForms() { registerFunctionCallToSpecialForm( expression::kTry, std::make_unique<TryCallToSpecialForm>()); registerFunctionCallToSpecialForm( - RowConstructorCallToSpecialForm::kRowConstructor, + expression::kRowConstructor, std::make_unique<RowConstructorCallToSpecialForm>()); } + } // namespace facebook::velox::exec diff --git a/velox/expression/RegisterSpecialForm.h b/velox/expression/RegisterSpecialForm.h index bd71ac53b83a..90e57dee3e94 100644 --- a/velox/expression/RegisterSpecialForm.h +++ b/velox/expression/RegisterSpecialForm.h @@ -17,5 +17,7 @@ #pragma once namespace facebook::velox::exec { + void registerFunctionCallToSpecialForms(); -} + +} // namespace facebook::velox::exec diff --git a/velox/expression/RowConstructor.cpp
b/velox/expression/RowConstructor.cpp index 59f3f2b42709..61443e5acd54 100644 --- a/velox/expression/RowConstructor.cpp +++ b/velox/expression/RowConstructor.cpp @@ -15,6 +15,7 @@ */ #include "velox/expression/RowConstructor.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/VectorFunction.h" namespace facebook::velox::exec { @@ -40,13 +41,15 @@ ExprPtr RowConstructorCallToSpecialForm::constructSpecialForm( [&config](auto& functionMap) -> std::pair< std::shared_ptr<VectorFunction>, VectorFunctionMetadata> { - auto functionIterator = functionMap.find(kRowConstructor); + auto functionIterator = functionMap.find(expression::kRowConstructor); if (functionIterator != functionMap.end()) { return { - functionIterator->second.factory(kRowConstructor, {}, config), + functionIterator->second.factory( + expression::kRowConstructor, {}, config), functionIterator->second.metadata}; } else { - VELOX_FAIL("Function {} is not registered.", kRowConstructor); + VELOX_FAIL( + "Function {} is not registered.", expression::kRowConstructor); } }); @@ -55,7 +58,7 @@ ExprPtr RowConstructorCallToSpecialForm::constructSpecialForm( std::move(compiledChildren), function, metadata, - kRowConstructor, + expression::kRowConstructor, trackCpuUsage); } } // namespace facebook::velox::exec diff --git a/velox/expression/RowConstructor.h b/velox/expression/RowConstructor.h index 9e29208536cc..dd84d26ca973 100644 --- a/velox/expression/RowConstructor.h +++ b/velox/expression/RowConstructor.h @@ -28,7 +28,5 @@ class RowConstructorCallToSpecialForm : public FunctionCallToSpecialForm { std::vector<ExprPtr>&& compiledChildren, bool trackCpuUsage, const core::QueryConfig& config) override; - - static constexpr const char* kRowConstructor = "row_constructor"; }; } // namespace facebook::velox::exec diff --git a/velox/expression/SignatureBinder.cpp b/velox/expression/SignatureBinder.cpp index c73b75c00d74..54a29cebe362 100644 --- a/velox/expression/SignatureBinder.cpp +++
b/velox/expression/SignatureBinder.cpp @@ -19,6 +19,7 @@ #include "velox/expression/SignatureBinder.h" #include "velox/expression/type_calculation/TypeCalculation.h" #include "velox/type/Type.h" +#include "velox/type/TypeUtil.h" namespace facebook::velox::exec { namespace { @@ -301,6 +302,45 @@ bool SignatureBinderBase::tryBind( } const auto& params = typeSignature.parameters(); + + // Handle homogeneous row case: row(T, ...) + if (typeSignature.isHomogeneousRow()) { + VELOX_CHECK_EQ( + params.size(), 1, "Homogeneous row must have exactly one parameter"); + + if (actualType->kind() != TypeKind::ROW) { + return false; + } + + if (actualType->size() == 0) { + // Empty row is always compatible with homogeneous row. + return true; + } + + // All children must unify to the same type variable T + const auto& typeParam = params[0]; + const auto& paramBaseName = typeParam.baseName(); + + // First, check and extract the common child type if homogeneous. + const auto actualChildType = + velox::type::tryGetHomogeneousRowChild(actualType); + if (!actualChildType) { + return false; + } + + if (variables().count(paramBaseName)) { + auto it = typeVariablesBindings_.find(paramBaseName); + if (it != typeVariablesBindings_.end()) { + return it->second->equivalent(*actualChildType); + } else { + typeVariablesBindings_[paramBaseName] = actualChildType; + return true; + } + } else { + return tryBind(typeParam, actualChildType); + } + } + // Type Parameters can recurse. 
if (params.size() != actualType->parameters().size()) { return false; diff --git a/velox/expression/SwitchExpr.cpp b/velox/expression/SwitchExpr.cpp index 835092cf4b5f..9baff240ce33 100644 --- a/velox/expression/SwitchExpr.cpp +++ b/velox/expression/SwitchExpr.cpp @@ -16,6 +16,7 @@ #include "velox/expression/SwitchExpr.h" #include "velox/expression/BooleanMix.h" #include "velox/expression/ConstantExpr.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/FieldReference.h" #include "velox/expression/ScopedVarSetter.h" @@ -35,7 +36,7 @@ SwitchExpr::SwitchExpr( SpecialFormKind::kSwitch, std::move(type), inputs, - "switch", + expression::kSwitch, hasElseClause(inputs) && inputsSupportFlatNoNullsFastPath, false /* trackCpuUsage */), numCases_{inputs_.size() / 2}, diff --git a/velox/expression/TypeSignature.cpp b/velox/expression/TypeSignature.cpp index ba5385fb9142..422e3b68d8e0 100644 --- a/velox/expression/TypeSignature.cpp +++ b/velox/expression/TypeSignature.cpp @@ -30,7 +30,11 @@ std::string TypeSignature::toString() const { } out << baseName_; if (!parameters_.empty()) { - out << "(" << folly::join(",", parameters_) << ")"; + if (isHomogeneousRow()) { + out << "(" << parameters_[0].toString() << ", ..." << ")"; + } else { + out << "(" << folly::join(",", parameters_) << ")"; + } } return out.str(); } diff --git a/velox/expression/TypeSignature.h b/velox/expression/TypeSignature.h index a8f66eb41fed..13ae2bfad49e 100644 --- a/velox/expression/TypeSignature.h +++ b/velox/expression/TypeSignature.h @@ -33,13 +33,16 @@ class TypeSignature { /// @param rowFieldName if this type signature is a field of another parent /// row type, it can optionally have a name. E.g. `row(id bigint)` would have /// "id" set as rowFieldName in the "bigint" parameter. + /// @param variableArity indicates if the last parameter is variadic. 
TypeSignature( std::string baseName, std::vector<TypeSignature> parameters, - std::optional<std::string> rowFieldName = std::nullopt) + std::optional<std::string> rowFieldName = std::nullopt, + bool variableArity = false) : baseName_{std::move(baseName)}, parameters_{std::move(parameters)}, - rowFieldName_(std::move(rowFieldName)) {} + rowFieldName_(std::move(rowFieldName)), + variableArity_{variableArity} {} const std::string& baseName() const { return baseName_; @@ -53,11 +56,20 @@ class TypeSignature { return rowFieldName_; } + bool variableArity() const { + return variableArity_; + } + + bool isHomogeneousRow() const { + return baseName_ == "row" && parameters_.size() == 1 && variableArity_; + } + std::string toString() const; bool operator==(const TypeSignature& rhs) const { return baseName_ == rhs.baseName_ && parameters_ == rhs.parameters_ && - rowFieldName_ == rhs.rowFieldName_; + rowFieldName_ == rhs.rowFieldName_ && + variableArity_ == rhs.variableArity_; } private: @@ -67,6 +79,9 @@ class TypeSignature { // If this object is a field of another parent row type, it can optionally // have a name, e.g, `row(id bigint)` const std::optional<std::string> rowFieldName_; + + // Indicates if the parameter is variadic. + const bool variableArity_; }; using TypeSignaturePtr = std::shared_ptr<const TypeSignature>; diff --git a/velox/expression/UdfTypeResolver.h b/velox/expression/UdfTypeResolver.h index b27c197dbdb4..90472c3ec44c 100644 --- a/velox/expression/UdfTypeResolver.h +++ b/velox/expression/UdfTypeResolver.h @@ -150,6 +150,13 @@ struct resolver { using out_type = int32_t; }; +template <> +struct resolver