Skip to content

Commit 5d5fc13

Browse files
duxiao1212 and facebook-github-bot
authored and committed
feat: Impl sort key for LocalShuffleReader (#26620)
Summary: Implement a k-way merge in LocalShuffleReader for sorted shuffle (used when sortedShuffle is enabled). Added k-way merge support using TreeOfLosers to efficiently merge multiple sorted shuffle files. The reader streams data from the sorted files and returns the merged results in sorted order. Reviewed By: tanjialiang Differential Revision: D86888221
1 parent a87fdb9 commit 5d5fc13

File tree

7 files changed

+430
-47
lines changed

7 files changed

+430
-47
lines changed

presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp

Lines changed: 198 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,102 @@
1515
#include "presto_cpp/external/json/nlohmann/json.hpp"
1616
#include "presto_cpp/main/common/Configs.h"
1717

18-
#include <folly/lang/Bits.h>
18+
#include "velox/common/file/FileInputStream.h"
1919

20-
using namespace facebook::velox::exec;
21-
using namespace facebook::velox;
20+
#include <boost/range/algorithm/sort.hpp>
2221

2322
namespace facebook::presto::operators {
2423

2524
using json = nlohmann::json;
2625

2726
namespace {
2827

28+
// Index type identifying one input stream of the sorted-shuffle merge; it is
// also used to break ties between streams whose current keys are equal.
using TStreamIdx = uint16_t;
29+
30+
/// SortedFileInputStream reads sorted (key, data) pairs from a single
31+
/// shuffle file with buffered I/O. It extends FileInputStream for efficient
32+
/// buffered I/O and implements MergeStream interface for k-way merge.
33+
class SortedFileInputStream final : public velox::common::FileInputStream,
34+
public velox::MergeStream {
35+
public:
36+
SortedFileInputStream(
37+
const std::string& filePath,
38+
TStreamIdx streamIdx,
39+
velox::memory::MemoryPool* pool,
40+
size_t bufferSize = kDefaultInputStreamBufferSize)
41+
: velox::common::FileInputStream(
42+
velox::filesystems::getFileSystem(filePath, nullptr)
43+
->openFileForRead(filePath),
44+
bufferSize,
45+
pool),
46+
streamIdx_(streamIdx) {
47+
next();
48+
}
49+
50+
~SortedFileInputStream() override = default;
51+
52+
bool next() {
53+
if (atEnd()) {
54+
currentKey_ = {};
55+
currentData_ = {};
56+
keyStorage_.clear();
57+
dataStorage_.clear();
58+
return false;
59+
}
60+
const TRowSize keySize = folly::Endian::big(read<TRowSize>());
61+
const TRowSize dataSize = folly::Endian::big(read<TRowSize>());
62+
63+
currentKey_ = nextStringView(keySize, keyStorage_);
64+
currentData_ = nextStringView(dataSize, dataStorage_);
65+
return true;
66+
}
67+
68+
std::string_view currentKey() const {
69+
return currentKey_;
70+
}
71+
72+
std::string_view currentData() const {
73+
return currentData_;
74+
}
75+
76+
bool hasData() const override {
77+
return !currentData_.empty() || !atEnd();
78+
}
79+
80+
bool operator<(const velox::MergeStream& other) const override {
81+
const auto* otherReader = static_cast<const SortedFileInputStream*>(&other);
82+
if (currentKey_ != otherReader->currentKey_) {
83+
return compareKeys(currentKey_, otherReader->currentKey_);
84+
}
85+
return streamIdx_ < otherReader->streamIdx_;
86+
}
87+
88+
private:
89+
// Returns string_view using zero-copy when data fits in buffer,
90+
// otherwise copies to storage when crossing buffer boundaries.
91+
std::string_view nextStringView(TRowSize size, std::string& storage) {
92+
if (size == 0) {
93+
return {};
94+
}
95+
auto view = nextView(size);
96+
if (view.size() == size) {
97+
return view;
98+
}
99+
storage.resize(size);
100+
std::memcpy(storage.data(), view.data(), view.size());
101+
readBytes(
102+
reinterpret_cast<uint8_t*>(storage.data()) + view.size(),
103+
size - view.size());
104+
return std::string_view(storage);
105+
}
106+
107+
const TStreamIdx streamIdx_;
108+
std::string_view currentKey_;
109+
std::string_view currentData_;
110+
std::string keyStorage_;
111+
std::string dataStorage_;
112+
};
113+
29114
std::vector<RowMetadata>
30115
extractRowMetadata(const char* buffer, size_t bufferSize, bool sortedShuffle) {
31116
std::vector<RowMetadata> rows;
@@ -91,13 +176,9 @@ extractRowMetadata(const char* buffer, size_t bufferSize, bool sortedShuffle) {
91176

92177
inline std::string_view
93178
extractRowData(const RowMetadata& row, const char* buffer, bool sortedShuffle) {
94-
if (sortedShuffle) {
95-
const size_t dataOffset = row.rowStart + (kUint32Size * 2) + row.keySize;
96-
return {buffer + dataOffset, row.dataSize};
97-
} else {
98-
const size_t dataOffset = row.rowStart + kUint32Size;
99-
return {buffer + dataOffset, row.dataSize};
100-
}
179+
const auto dataOffset = row.rowStart +
180+
(sortedShuffle ? (kUint32Size * 2) + row.keySize : kUint32Size);
181+
return {buffer + dataOffset, row.dataSize};
101182
}
102183

103184
std::vector<RowMetadata> extractAndSortRowMetadata(
@@ -106,10 +187,8 @@ std::vector<RowMetadata> extractAndSortRowMetadata(
106187
bool sortedShuffle) {
107188
auto rows = extractRowMetadata(buffer, bufferSize, sortedShuffle);
108189
if (!rows.empty() && sortedShuffle) {
109-
std::sort(
110-
rows.begin(),
111-
rows.end(),
112-
[buffer](const RowMetadata& lhs, const RowMetadata& rhs) {
190+
boost::range::sort(
191+
rows, [buffer](const RowMetadata& lhs, const RowMetadata& rhs) {
113192
const char* lhsKey = buffer + lhs.rowStart + (kUint32Size * 2);
114193
const char* rhsKey = buffer + rhs.rowStart + (kUint32Size * 2);
115194
return compareKeys(
@@ -147,6 +226,7 @@ LocalShuffleWriteInfo LocalShuffleWriteInfo::deserialize(
147226
jsonReadInfo.at("queryId").get_to(shuffleInfo.queryId);
148227
jsonReadInfo.at("shuffleId").get_to(shuffleInfo.shuffleId);
149228
jsonReadInfo.at("numPartitions").get_to(shuffleInfo.numPartitions);
229+
shuffleInfo.sortedShuffle = jsonReadInfo.value("sortedShuffle", false);
150230
return shuffleInfo;
151231
}
152232

@@ -157,6 +237,7 @@ LocalShuffleReadInfo LocalShuffleReadInfo::deserialize(
157237
jsonReadInfo.at("rootPath").get_to(shuffleInfo.rootPath);
158238
jsonReadInfo.at("queryId").get_to(shuffleInfo.queryId);
159239
jsonReadInfo.at("partitionIds").get_to(shuffleInfo.partitionIds);
240+
shuffleInfo.sortedShuffle = jsonReadInfo.value("sortedShuffle", false);
160241
return shuffleInfo;
161242
}
162243

@@ -276,10 +357,11 @@ void LocalShuffleWriter::collect(
276357
sortedShuffle_ || key.empty(),
277358
"key '{}' must be empty for non-sorted shuffle",
278359
key);
360+
279361
const auto rowSize = this->rowSize(key.size(), data.size());
280362
auto& buffer = inProgressPartitions_[partition];
281363
if (buffer == nullptr) {
282-
buffer = AlignedBuffer::allocate<char>(
364+
buffer = velox::AlignedBuffer::allocate<char>(
283365
std::max(static_cast<uint64_t>(rowSize), maxBytesPerPartition_),
284366
pool_,
285367
0);
@@ -319,31 +401,105 @@ LocalShuffleReader::LocalShuffleReader(
319401
fileSystem_ = velox::filesystems::getFileSystem(rootPath_, nullptr);
320402
}
321403

322-
folly::SemiFuture<std::vector<std::unique_ptr<ReadBatch>>>
323-
LocalShuffleReader::next(uint64_t maxBytes) {
324-
if (readPartitionFiles_.empty()) {
325-
readPartitionFiles_ = getReadPartitionFiles();
404+
void LocalShuffleReader::initialize() {
405+
VELOX_CHECK(!initialized_, "LocalShuffleReader already initialized");
406+
readPartitionFiles_ = getReadPartitionFiles();
407+
408+
if (sortedShuffle_ && !readPartitionFiles_.empty()) {
409+
std::vector<std::unique_ptr<velox::MergeStream>> streams;
410+
streams.reserve(readPartitionFiles_.size());
411+
TStreamIdx streamIdx = 0;
412+
for (const auto& filename : readPartitionFiles_) {
413+
VELOX_CHECK(
414+
!filename.empty(),
415+
"Invalid empty shuffle file path for query {}, partitions: [{}]",
416+
queryId_,
417+
folly::join(", ", partitionIds_));
418+
auto reader =
419+
std::make_unique<SortedFileInputStream>(filename, streamIdx, pool_);
420+
if (reader->hasData()) {
421+
streams.push_back(std::move(reader));
422+
++streamIdx;
423+
}
424+
}
425+
if (!streams.empty()) {
426+
merge_ =
427+
std::make_unique<velox::TreeOfLosers<velox::MergeStream, uint16_t>>(
428+
std::move(streams));
429+
}
326430
}
327431

432+
initialized_ = true;
433+
}
434+
435+
std::vector<std::unique_ptr<ReadBatch>> LocalShuffleReader::nextSorted(
436+
uint64_t maxBytes) {
437+
std::vector<std::unique_ptr<ReadBatch>> batches;
438+
439+
if (merge_ == nullptr) {
440+
return batches;
441+
}
442+
443+
auto batchBuffer = velox::AlignedBuffer::allocate<char>(maxBytes, pool_, 0);
444+
std::vector<std::string_view> rows;
445+
uint64_t bufferUsed = 0;
446+
447+
while (auto* stream = merge_->next()) {
448+
auto* reader = dynamic_cast<SortedFileInputStream*>(stream);
449+
const auto data = reader->currentData();
450+
451+
if (bufferUsed + data.size() > maxBytes) {
452+
if (bufferUsed > 0) {
453+
batches.push_back(
454+
std::make_unique<ReadBatch>(
455+
std::move(rows), std::move(batchBuffer)));
456+
return batches;
457+
}
458+
// Single row exceeds buffer - allocate larger buffer
459+
batchBuffer = velox::AlignedBuffer::allocate<char>(data.size(), pool_, 0);
460+
bufferUsed = 0;
461+
}
462+
463+
char* writePos = batchBuffer->asMutable<char>() + bufferUsed;
464+
if (!data.empty()) {
465+
memcpy(writePos, data.data(), data.size());
466+
}
467+
468+
rows.emplace_back(batchBuffer->as<char>() + bufferUsed, data.size());
469+
bufferUsed += data.size();
470+
reader->next();
471+
}
472+
473+
if (!rows.empty()) {
474+
batches.push_back(
475+
std::make_unique<ReadBatch>(std::move(rows), std::move(batchBuffer)));
476+
}
477+
478+
return batches;
479+
}
480+
481+
std::vector<std::unique_ptr<ReadBatch>> LocalShuffleReader::nextUnsorted(
482+
uint64_t maxBytes) {
328483
std::vector<std::unique_ptr<ReadBatch>> batches;
329484
uint64_t totalBytes{0};
330-
// Read files until we reach maxBytes limit or run out of files.
485+
331486
while (readPartitionFileIndex_ < readPartitionFiles_.size()) {
332487
const auto filename = readPartitionFiles_[readPartitionFileIndex_];
333488
auto file = fileSystem_->openFileForRead(filename);
334489
const auto fileSize = file->size();
335490

336-
// Stop if adding this file would exceed maxBytes (unless we haven't read
337-
// any files yet)
491+
// TODO: Refactor to use streaming I/O with bounded buffer size instead of
492+
// loading entire files into memory at once. A streaming approach would
493+
// reduce peak memory consumption and enable processing arbitrarily large
494+
// shuffle files while maintaining constant memory usage.
338495
if (!batches.empty() && totalBytes + fileSize > maxBytes) {
339496
break;
340497
}
341498

342-
auto buffer = AlignedBuffer::allocate<char>(fileSize, pool_, 0);
499+
auto buffer = velox::AlignedBuffer::allocate<char>(fileSize, pool_, 0);
343500
file->pread(0, fileSize, buffer->asMutable<void>());
344501
++readPartitionFileIndex_;
345502

346-
// Parse the buffer to extract individual rows
347503
const char* data = buffer->as<char>();
348504
const auto parsedRows = extractRowMetadata(data, fileSize, sortedShuffle_);
349505
std::vector<std::string_view> rows;
@@ -357,7 +513,17 @@ LocalShuffleReader::next(uint64_t maxBytes) {
357513
std::make_unique<ReadBatch>(std::move(rows), std::move(buffer)));
358514
}
359515

360-
return folly::makeSemiFuture(std::move(batches));
516+
return batches;
517+
}
518+
519+
/// Returns the next batches of shuffle rows as an already-completed future.
/// Dispatches to nextSorted() for a sorted shuffle and to nextUnsorted()
/// otherwise. Fails if initialize() has not been called.
folly::SemiFuture<std::vector<std::unique_ptr<ReadBatch>>>
LocalShuffleReader::next(uint64_t maxBytes) {
  VELOX_CHECK(
      initialized_,
      "LocalShuffleReader::initialize() must be called before next()");

  if (sortedShuffle_) {
    return folly::makeSemiFuture(nextSorted(maxBytes));
  }
  return folly::makeSemiFuture(nextUnsorted(maxBytes));
}
362528

363529
void LocalShuffleReader::noMoreData(bool success) {
@@ -403,12 +569,15 @@ std::shared_ptr<ShuffleReader> LocalPersistentShuffleFactory::createReader(
403569
velox::memory::MemoryPool* pool) {
404570
const operators::LocalShuffleReadInfo readInfo =
405571
operators::LocalShuffleReadInfo::deserialize(serializedStr);
406-
return std::make_shared<LocalShuffleReader>(
572+
573+
auto reader = std::make_shared<LocalShuffleReader>(
407574
readInfo.rootPath,
408575
readInfo.queryId,
409576
readInfo.partitionIds,
410-
/*sortShuffle=*/false, // default to false for now
577+
readInfo.sortedShuffle,
411578
pool);
579+
reader->initialize();
580+
return reader;
412581
}
413582

414583
std::shared_ptr<ShuffleWriter> LocalPersistentShuffleFactory::createWriter(
@@ -418,13 +587,14 @@ std::shared_ptr<ShuffleWriter> LocalPersistentShuffleFactory::createWriter(
418587
SystemConfig::instance()->localShuffleMaxPartitionBytes();
419588
const operators::LocalShuffleWriteInfo writeInfo =
420589
operators::LocalShuffleWriteInfo::deserialize(serializedStr);
590+
421591
return std::make_shared<LocalShuffleWriter>(
422592
writeInfo.rootPath,
423593
writeInfo.queryId,
424594
writeInfo.shuffleId,
425595
writeInfo.numPartitions,
426596
maxBytesPerPartition,
427-
/*sortedShuffle=*/false, // default to false for now
597+
writeInfo.sortedShuffle,
428598
pool);
429599
}
430600

@@ -436,5 +606,4 @@ std::vector<RowMetadata> testingExtractRowMetadata(
436606
bool sortedShuffle) {
437607
return extractRowMetadata(buffer, bufferSize, sortedShuffle);
438608
}
439-
440609
} // namespace facebook::presto::operators

0 commit comments

Comments
 (0)