diff --git a/velox/experimental/cudf/exec/CudfHashJoin.cpp b/velox/experimental/cudf/exec/CudfHashJoin.cpp index f4e59a168738..99c6385e6dd3 100644 --- a/velox/experimental/cudf/exec/CudfHashJoin.cpp +++ b/velox/experimental/cudf/exec/CudfHashJoin.cpp @@ -376,7 +376,13 @@ CudfHashJoinProbe::CudfHashJoinProbe( } bool CudfHashJoinProbe::needsInput() const { - return !finished_ && input_ == nullptr; + if (cudfDebugEnabled()) { + std::cout << "Calling CudfHashJoinProbe::needsInput" << std::endl; + } + if (joinNode_->isRightJoin()) { + return !noMoreInput_; + } + return !noMoreInput_ && !finished_ && input_ == nullptr; } void CudfHashJoinProbe::addInput(RowVectorPtr input) { @@ -397,7 +403,81 @@ void CudfHashJoinProbe::addInput(RowVectorPtr input) { auto lockedStats = stats_.wlock(); lockedStats->numNullKeys += null_count; } - input_ = std::move(input); + if (!joinNode_->isRightJoin()) { + input_ = std::move(input); + return; + } + + // Queue inputs and process all at once + if (input->size() > 0) { + inputs_.push_back(std::move(cudfInput)); + } +} + +void CudfHashJoinProbe::noMoreInput() { + if (cudfDebugEnabled()) { + std::cout << "Calling CudfHashJoinProbe::noMoreInput" << std::endl; + } + VELOX_NVTX_OPERATOR_FUNC_RANGE(); + Operator::noMoreInput(); + if (!joinNode_->isRightJoin()) { + return; + } + std::vector promises; + std::vector> peers; + // Only last driver collects all answers + if (!operatorCtx_->task()->allPeersFinished( + planNodeId(), operatorCtx_->driver(), &future_, promises, peers)) { + return; + } + // Collect results from peers + for (auto& peer : peers) { + auto op = peer->findOperator(planNodeId()); + auto* probe = dynamic_cast(op); + VELOX_CHECK_NOT_NULL(probe); + inputs_.insert(inputs_.end(), probe->inputs_.begin(), probe->inputs_.end()); + } + + SCOPE_EXIT { + // Realize the promises so that the other Drivers (which were not + // the last to finish) can continue from the barrier and finish. + peers.clear(); + for (auto& promise : promises) { + promise.setValue(); + } + }; + + auto stream = cudfGlobalStreamPool().get_stream(); + std::unique_ptr tbl; + if (inputs_.size() == 0) { + auto emptyRowVector = RowVector::createEmpty( + joinNode_->sources()[1]->outputType(), operatorCtx_->pool()); + tbl = facebook::velox::cudf_velox::with_arrow::toCudfTable( + emptyRowVector, operatorCtx_->pool(), stream); + } else { + tbl = getConcatenatedTable(inputs_, stream); + } + + // Release input data after synchronizing + stream.synchronize(); + + VELOX_CHECK_NOT_NULL(tbl); + + if (cudfDebugEnabled()) { + std::cout << "Probe table number of columns: " << tbl->num_columns() + << std::endl; + std::cout << "Probe table number of rows: " << tbl->num_rows() << std::endl; + } + + // Store the concatenated table in input_ + input_ = std::make_shared( + operatorCtx_->pool(), + joinNode_->outputType(), + tbl->num_rows(), + std::move(tbl), + stream); + + inputs_.clear(); } RowVectorPtr CudfHashJoinProbe::getOutput() { @@ -644,6 +724,14 @@ bool CudfHashJoinProbe::skipProbeOnEmptyBuild() const { } exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) { + if (joinNode_->isRightJoin() && hashObject_.has_value()) { + if (!future_.valid()) { + return exec::BlockingReason::kNotBlocked; + } + *future = std::move(future_); + return exec::BlockingReason::kWaitForJoinProbe; + } + if (hashObject_.has_value()) { return exec::BlockingReason::kNotBlocked; } @@ -679,6 +767,10 @@ exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) { } } } + if (joinNode_->isRightJoin() && future_.valid()) { + *future = std::move(future_); + return exec::BlockingReason::kWaitForJoinProbe; + } return exec::BlockingReason::kNotBlocked; } diff --git a/velox/experimental/cudf/exec/CudfHashJoin.h b/velox/experimental/cudf/exec/CudfHashJoin.h index 6d7955ddb10c..da2862072f8d 100644 --- a/velox/experimental/cudf/exec/CudfHashJoin.h +++ b/velox/experimental/cudf/exec/CudfHashJoin.h @@ -82,6 +82,8 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper { void addInput(RowVectorPtr input) override; + void noMoreInput() override; + RowVectorPtr getOutput() override; bool skipProbeOnEmptyBuild() const; @@ -92,7 +94,8 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper { return joinType == core::JoinType::kInner || joinType == core::JoinType::kLeft || joinType == core::JoinType::kAnti || - joinType == core::JoinType::kLeftSemiFilter; + joinType == core::JoinType::kLeftSemiFilter || + joinType == core::JoinType::kRight; } bool isFinished() override; @@ -107,6 +110,10 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper { bool rightPrecomputed_{false}; + // Batched probe inputs needed for right join + std::vector inputs_; + ContinueFuture future_{ContinueFuture::makeEmpty()}; + std::vector leftKeyIndices_; std::vector rightKeyIndices_; std::vector leftColumnIndicesToGather_;