Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 94 additions & 2 deletions velox/experimental/cudf/exec/CudfHashJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,13 @@ CudfHashJoinProbe::CudfHashJoinProbe(
}

bool CudfHashJoinProbe::needsInput() const {
return !finished_ && input_ == nullptr;
if (cudfDebugEnabled()) {
std::cout << "Calling CudfHashJoinProbe::needsInput" << std::endl;
}
if (joinNode_->isRightJoin()) {
return !noMoreInput_;
}
return !noMoreInput_ && !finished_ && input_ == nullptr;
}

void CudfHashJoinProbe::addInput(RowVectorPtr input) {
Expand All @@ -397,7 +403,81 @@ void CudfHashJoinProbe::addInput(RowVectorPtr input) {
auto lockedStats = stats_.wlock();
lockedStats->numNullKeys += null_count;
}
input_ = std::move(input);
if (!joinNode_->isRightJoin()) {
input_ = std::move(input);
return;
}

// Queue inputs and process all at once
if (input->size() > 0) {
inputs_.push_back(std::move(cudfInput));
}
}

void CudfHashJoinProbe::noMoreInput() {
if (cudfDebugEnabled()) {
std::cout << "Calling CudfHashJoinProbe::noMoreInput" << std::endl;
}
VELOX_NVTX_OPERATOR_FUNC_RANGE();
Operator::noMoreInput();
if (!joinNode_->isRightJoin()) {
return;
}
std::vector<ContinuePromise> promises;
std::vector<std::shared_ptr<exec::Driver>> peers;
// Only last driver collects all answers
if (!operatorCtx_->task()->allPeersFinished(
planNodeId(), operatorCtx_->driver(), &future_, promises, peers)) {
return;
}
// Collect results from peers
for (auto& peer : peers) {
auto op = peer->findOperator(planNodeId());
auto* probe = dynamic_cast<CudfHashJoinProbe*>(op);
VELOX_CHECK_NOT_NULL(probe);
inputs_.insert(inputs_.end(), probe->inputs_.begin(), probe->inputs_.end());
}

SCOPE_EXIT {
// Realize the promises so that the other Drivers (which were not
// the last to finish) can continue from the barrier and finish.
peers.clear();
for (auto& promise : promises) {
promise.setValue();
}
};

auto stream = cudfGlobalStreamPool().get_stream();
std::unique_ptr<cudf::table> tbl;
if (inputs_.size() == 0) {
auto emptyRowVector = RowVector::createEmpty(
joinNode_->sources()[1]->outputType(), operatorCtx_->pool());
tbl = facebook::velox::cudf_velox::with_arrow::toCudfTable(
emptyRowVector, operatorCtx_->pool(), stream);
} else {
tbl = getConcatenatedTable(inputs_, stream);
}

// Release input data after synchronizing
stream.synchronize();

VELOX_CHECK_NOT_NULL(tbl);

if (cudfDebugEnabled()) {
std::cout << "Probe table number of columns: " << tbl->num_columns()
<< std::endl;
std::cout << "Probe table number of rows: " << tbl->num_rows() << std::endl;
}

// Store the concatenated table in input_
input_ = std::make_shared<CudfVector>(
operatorCtx_->pool(),
joinNode_->outputType(),
tbl->num_rows(),
std::move(tbl),
stream);

inputs_.clear();
}

RowVectorPtr CudfHashJoinProbe::getOutput() {
Expand Down Expand Up @@ -644,6 +724,14 @@ bool CudfHashJoinProbe::skipProbeOnEmptyBuild() const {
}

exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) {
if (joinNode_->isRightJoin() && hashObject_.has_value()) {
if (!future_.valid()) {
return exec::BlockingReason::kNotBlocked;
}
*future = std::move(future_);
return exec::BlockingReason::kWaitForJoinProbe;
}

if (hashObject_.has_value()) {
return exec::BlockingReason::kNotBlocked;
}
Expand Down Expand Up @@ -679,6 +767,10 @@ exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) {
}
}
}
if (joinNode_->isRightJoin() && future_.valid()) {
*future = std::move(future_);
return exec::BlockingReason::kWaitForJoinProbe;
}
return exec::BlockingReason::kNotBlocked;
}

Expand Down
9 changes: 8 additions & 1 deletion velox/experimental/cudf/exec/CudfHashJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper {

void addInput(RowVectorPtr input) override;

void noMoreInput() override;

RowVectorPtr getOutput() override;

bool skipProbeOnEmptyBuild() const;
Expand All @@ -92,7 +94,8 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper {
return joinType == core::JoinType::kInner ||
joinType == core::JoinType::kLeft ||
joinType == core::JoinType::kAnti ||
joinType == core::JoinType::kLeftSemiFilter;
joinType == core::JoinType::kLeftSemiFilter ||
joinType == core::JoinType::kRight;
}

bool isFinished() override;
Expand All @@ -107,6 +110,10 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper {

bool rightPrecomputed_{false};

// Batched probe inputs needed for right join
std::vector<CudfVectorPtr> inputs_;
ContinueFuture future_{ContinueFuture::makeEmpty()};

std::vector<cudf::size_type> leftKeyIndices_;
std::vector<cudf::size_type> rightKeyIndices_;
std::vector<cudf::size_type> leftColumnIndicesToGather_;
Expand Down
Loading