diff --git a/presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp b/presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp index 2730377edbea8..05b79a00799b6 100644 --- a/presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp +++ b/presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp @@ -15,10 +15,10 @@ #include "presto_cpp/external/json/nlohmann/json.hpp" #include "presto_cpp/main/common/Configs.h" -#include +#include "velox/common/Casts.h" +#include "velox/common/file/FileInputStream.h" -using namespace facebook::velox::exec; -using namespace facebook::velox; +#include namespace facebook::presto::operators { @@ -26,6 +26,85 @@ using json = nlohmann::json; namespace { +using TStreamIdx = uint16_t; + +// Default buffer size for SortedFileInputStream +// This buffer is used for streaming reads from shuffle files during k-way +// merge. +constexpr uint64_t kDefaultInputStreamBufferSize = 8 * 1024 * 1024; // 8MB + +/// SortedFileInputStream reads sorted (key, data) pairs from a single +/// shuffle file with buffered I/O. It extends FileInputStream for efficient +/// buffered I/O and implements MergeStream interface for k-way merge. +class SortedFileInputStream final : public velox::common::FileInputStream, + public velox::MergeStream { + public: + SortedFileInputStream( + const std::string& filePath, + TStreamIdx streamIdx, + velox::memory::MemoryPool* pool, + size_t bufferSize = kDefaultInputStreamBufferSize) + : velox::common::FileInputStream( + velox::filesystems::getFileSystem(filePath, nullptr) + ->openFileForRead(filePath), + bufferSize, + pool), + streamIdx_(streamIdx) { + next(); + } + + ~SortedFileInputStream() override = default; + + bool next() { + if (atEnd()) { + currentKey_.clear(); + currentValue_.clear(); + return false; + } + const TRowSize keySize = folly::Endian::big(read()); + const TRowSize valueSize = folly::Endian::big(read()); + + // TODO: Optimize with zero-copy approach when data is contiguous in buffer. + readString(currentKey_, keySize); + readString(currentValue_, valueSize); + return true; + } + + std::string_view currentKey() const { + return currentKey_; + } + + std::string_view currentValue() const { + return currentValue_; + } + + bool hasData() const override { + return !currentValue_.empty() || !atEnd(); + } + + bool operator<(const velox::MergeStream& other) const override { + const auto* otherReader = static_cast(&other); + if (currentKey_ != otherReader->currentKey_) { + return compareKeys(currentKey_, otherReader->currentKey_); + } + return streamIdx_ < otherReader->streamIdx_; + } + + private: + void readString(std::string& target, TRowSize size) { + if (size > 0) { + target.resize(size); + readBytes(reinterpret_cast(target.data()), size); + } else { + target.clear(); + } + } + + const TStreamIdx streamIdx_; + std::string currentKey_; + std::string currentValue_; +}; + std::vector extractRowMetadata(const char* buffer, size_t bufferSize, bool sortedShuffle) { std::vector rows; @@ -91,13 +170,9 @@ extractRowMetadata(const char* buffer, size_t bufferSize, bool sortedShuffle) { inline std::string_view extractRowData(const RowMetadata& row, const char* buffer, bool sortedShuffle) { - if (sortedShuffle) { - const size_t dataOffset = row.rowStart + (kUint32Size * 2) + row.keySize; - return {buffer + dataOffset, row.dataSize}; - } else { - const size_t dataOffset = row.rowStart + kUint32Size; - return {buffer + dataOffset, row.dataSize}; - } + const auto dataOffset = row.rowStart + + (sortedShuffle ? (kUint32Size * 2) + row.keySize : kUint32Size); + return {buffer + dataOffset, row.dataSize}; } std::vector extractAndSortRowMetadata( @@ -106,10 +181,8 @@ std::vector extractAndSortRowMetadata( bool sortedShuffle) { auto rows = extractRowMetadata(buffer, bufferSize, sortedShuffle); if (!rows.empty() && sortedShuffle) { - std::sort( - rows.begin(), - rows.end(), - [buffer](const RowMetadata& lhs, const RowMetadata& rhs) { + boost::range::sort( + rows, [buffer](const RowMetadata& lhs, const RowMetadata& rhs) { const char* lhsKey = buffer + lhs.rowStart + (kUint32Size * 2); const char* rhsKey = buffer + rhs.rowStart + (kUint32Size * 2); return compareKeys( @@ -147,6 +220,7 @@ LocalShuffleWriteInfo LocalShuffleWriteInfo::deserialize( jsonReadInfo.at("queryId").get_to(shuffleInfo.queryId); jsonReadInfo.at("shuffleId").get_to(shuffleInfo.shuffleId); jsonReadInfo.at("numPartitions").get_to(shuffleInfo.numPartitions); + shuffleInfo.sortedShuffle = jsonReadInfo.value("sortedShuffle", false); return shuffleInfo; } @@ -157,6 +231,7 @@ LocalShuffleReadInfo LocalShuffleReadInfo::deserialize( jsonReadInfo.at("rootPath").get_to(shuffleInfo.rootPath); jsonReadInfo.at("queryId").get_to(shuffleInfo.queryId); jsonReadInfo.at("partitionIds").get_to(shuffleInfo.partitionIds); + shuffleInfo.sortedShuffle = jsonReadInfo.value("sortedShuffle", false); return shuffleInfo; } @@ -276,10 +351,11 @@ void LocalShuffleWriter::collect( sortedShuffle_ || key.empty(), "key '{}' must be empty for non-sorted shuffle", key); + const auto rowSize = this->rowSize(key.size(), data.size()); auto& buffer = inProgressPartitions_[partition]; if (buffer == nullptr) { - buffer = AlignedBuffer::allocate( + buffer = velox::AlignedBuffer::allocate( std::max(static_cast(rowSize), maxBytesPerPartition_), pool_, 0); @@ -319,31 +395,107 @@ LocalShuffleReader::LocalShuffleReader( fileSystem_ = velox::filesystems::getFileSystem(rootPath_, nullptr); } -folly::SemiFuture>> -LocalShuffleReader::next(uint64_t maxBytes) { - if (readPartitionFiles_.empty()) { - readPartitionFiles_ = getReadPartitionFiles(); +void LocalShuffleReader::initialize() { + VELOX_CHECK(!initialized_, "LocalShuffleReader already initialized"); + readPartitionFiles_ = getReadPartitionFiles(); + if (sortedShuffle_ && !readPartitionFiles_.empty()) { + initSortedShuffleRead(); } + initialized_ = true; +} + +void LocalShuffleReader::initSortedShuffleRead() { + std::vector> streams; + streams.reserve(readPartitionFiles_.size()); + TStreamIdx streamIdx = 0; + for (const auto& filename : readPartitionFiles_) { + VELOX_CHECK( + !filename.empty(), + "Invalid empty shuffle file path for query {}, partitions: [{}]", + queryId_, + folly::join(", ", partitionIds_)); + auto reader = + std::make_unique(filename, streamIdx, pool_); + if (reader->hasData()) { + streams.push_back(std::move(reader)); + ++streamIdx; + } + } + if (!streams.empty()) { + merge_ = + std::make_unique>( + std::move(streams)); + } +} + +std::vector> LocalShuffleReader::nextSorted( + uint64_t maxBytes) { + std::vector> batches; + + if (merge_ == nullptr) { + return batches; + } + + auto batchBuffer = velox::AlignedBuffer::allocate(maxBytes, pool_, 0); + std::vector rows; + uint64_t bufferUsed = 0; + + while (auto* stream = merge_->next()) { + auto* reader = velox::checked_pointer_cast(stream); + const auto data = reader->currentValue(); + + if (bufferUsed + data.size() > maxBytes) { + if (bufferUsed > 0) { + batches.push_back( + std::make_unique( + std::move(rows), std::move(batchBuffer))); + return batches; + } + // Single row exceeds buffer - allocate larger buffer + batchBuffer = velox::AlignedBuffer::allocate(data.size(), pool_, 0); + } + + char* writePos = batchBuffer->asMutable() + bufferUsed; + if (!data.empty()) { + memcpy(writePos, data.data(), data.size()); + } + + rows.emplace_back(batchBuffer->as() + bufferUsed, data.size()); + bufferUsed += data.size(); + reader->next(); + } + + if (!rows.empty()) { + batches.push_back( + std::make_unique(std::move(rows), std::move(batchBuffer))); + } + + return batches; +} + +std::vector> LocalShuffleReader::nextUnsorted( + uint64_t maxBytes) { std::vector> batches; uint64_t totalBytes{0}; - // Read files until we reach maxBytes limit or run out of files. + while (readPartitionFileIndex_ < readPartitionFiles_.size()) { const auto filename = readPartitionFiles_[readPartitionFileIndex_]; auto file = fileSystem_->openFileForRead(filename); const auto fileSize = file->size(); - // Stop if adding this file would exceed maxBytes (unless we haven't read - // any files yet) + // TODO: Refactor to use streaming I/O with bounded buffer size instead of + // loading entire files into memory at once. A streaming approach would + // reduce peak memory consumption and enable processing arbitrarily large + // shuffle files while maintaining constant memory usage. if (!batches.empty() && totalBytes + fileSize > maxBytes) { break; } - auto buffer = AlignedBuffer::allocate(fileSize, pool_, 0); + auto buffer = velox::AlignedBuffer::allocate(fileSize, pool_, 0); file->pread(0, fileSize, buffer->asMutable()); ++readPartitionFileIndex_; - // Parse the buffer to extract individual rows const char* data = buffer->as(); const auto parsedRows = extractRowMetadata(data, fileSize, sortedShuffle_); std::vector rows; @@ -357,7 +509,17 @@ LocalShuffleReader::next(uint64_t maxBytes) { std::make_unique(std::move(rows), std::move(buffer))); } - return folly::makeSemiFuture(std::move(batches)); + return batches; +} + +folly::SemiFuture>> +LocalShuffleReader::next(uint64_t maxBytes) { + VELOX_CHECK( + initialized_, + "LocalShuffleReader::initialize() must be called before next()"); + + return folly::makeSemiFuture( + sortedShuffle_ ? nextSorted(maxBytes) : nextUnsorted(maxBytes)); } void LocalShuffleReader::noMoreData(bool success) { @@ -403,12 +565,15 @@ std::shared_ptr LocalPersistentShuffleFactory::createReader( velox::memory::MemoryPool* pool) { const operators::LocalShuffleReadInfo readInfo = operators::LocalShuffleReadInfo::deserialize(serializedStr); - return std::make_shared( + + auto reader = std::make_shared( readInfo.rootPath, readInfo.queryId, readInfo.partitionIds, - /*sortShuffle=*/false, // default to false for now + readInfo.sortedShuffle, pool); + reader->initialize(); + return reader; } std::shared_ptr LocalPersistentShuffleFactory::createWriter( @@ -418,13 +583,14 @@ std::shared_ptr LocalPersistentShuffleFactory::createWriter( SystemConfig::instance()->localShuffleMaxPartitionBytes(); const operators::LocalShuffleWriteInfo writeInfo = operators::LocalShuffleWriteInfo::deserialize(serializedStr); + return std::make_shared( writeInfo.rootPath, writeInfo.queryId, writeInfo.shuffleId, writeInfo.numPartitions, maxBytesPerPartition, - /*sortedShuffle=*/false, // default to false for now + writeInfo.sortedShuffle, pool); } @@ -436,5 +602,4 @@ std::vector testingExtractRowMetadata( bool sortedShuffle) { return extractRowMetadata(buffer, bufferSize, sortedShuffle); } - } // namespace facebook::presto::operators diff --git a/presto-native-execution/presto_cpp/main/operators/LocalShuffle.h b/presto-native-execution/presto_cpp/main/operators/LocalShuffle.h index dfb0e90a1a012..30bc1bd14a659 100644 --- a/presto-native-execution/presto_cpp/main/operators/LocalShuffle.h +++ b/presto-native-execution/presto_cpp/main/operators/LocalShuffle.h @@ -18,6 +18,7 @@ #include #include +#include "velox/common/base/TreeOfLosers.h" #include "velox/common/file/FileSystems.h" #include "velox/exec/Operator.h" @@ -26,6 +27,7 @@ namespace facebook::presto::operators { using TRowSize = uint32_t; + constexpr size_t kUint32Size = sizeof(TRowSize); // Metadata describing a serialized row's location and sizes in a buffer @@ -35,17 +37,18 @@ struct RowMetadata { uint32_t dataSize; // Size of data }; -// Compare two sort keys lexicographically inline bool compareKeys(std::string_view key1, std::string_view key2) noexcept { - return std::lexicographical_compare( - reinterpret_cast(key1.data()), - reinterpret_cast(key1.data() + key1.size()), - reinterpret_cast(key2.data()), - reinterpret_cast(key2.data() + key2.size())); + const auto minSize = std::min(key1.size(), key2.size()); + if (minSize > 0) { + const int cmp = std::memcmp(key1.data(), key2.data(), minSize); + if (cmp != 0) { + return cmp < 0; + } + } + return key1.size() < key2.size(); } -// TODO: Testing function to expose extractRowMetadata for tests. -// This will be removed after reader changes. +// Testing function to expose extractRowMetadata for tests. std::vector testingExtractRowMetadata( const char* buffer, size_t bufferSize, @@ -62,6 +65,7 @@ struct LocalShuffleWriteInfo { std::string queryId; uint32_t numPartitions; uint32_t shuffleId; + bool sortedShuffle; /// Deserializes shuffle information that is used by LocalPersistentShuffle. /// Structures are assumed to be encoded in JSON format. @@ -78,6 +82,7 @@ struct LocalShuffleReadInfo { std::string rootPath; std::string queryId; std::vector partitionIds; + bool sortedShuffle; /// Deserializes shuffle information that is used by LocalPersistentShuffle. /// Structures are assumed to be encoded in JSON format. @@ -166,6 +171,11 @@ class LocalShuffleReader : public ShuffleReader { bool sortedShuffle, velox::memory::MemoryPool* pool); + /// Initializes the reader by discovering shuffle files and setting up merge + /// infrastructure for sorted shuffle. Must be called before next(). + /// For sorted shuffle, this opens all shuffle files and prepares k-way merge. + void initialize(); + folly::SemiFuture>> next( uint64_t maxBytes) override; @@ -180,6 +190,16 @@ class LocalShuffleReader : public ShuffleReader { // Returns all created shuffle files for 'partition_'. std::vector getReadPartitionFiles() const; + // Initializes sorted shuffle read by creating input streams and setting up + // k-way merge infrastructure. + void initSortedShuffleRead(); + + // Reads sorted shuffle data using k-way merge with TreeOfLosers. + std::vector> nextSorted(uint64_t maxBytes); + + // Reads unsorted shuffle data in batch-based file reading. + std::vector> nextUnsorted(uint64_t maxBytes); + const std::string rootPath_; const std::string queryId_; const std::vector partitionIds_; @@ -194,6 +214,11 @@ class LocalShuffleReader : public ShuffleReader { // The top directory of the shuffle files and its file system. std::shared_ptr fileSystem_; + + // Used to merge sorted streams from multiple shuffle files for k-way merge. + std::unique_ptr> merge_; + + bool initialized_{false}; }; class LocalPersistentShuffleFactory : public ShuffleInterfaceFactory { diff --git a/presto-native-execution/presto_cpp/main/operators/tests/ShuffleTest.cpp b/presto-native-execution/presto_cpp/main/operators/tests/ShuffleTest.cpp index d6d5c4c8a2735..086cb47029a20 100644 --- a/presto-native-execution/presto_cpp/main/operators/tests/ShuffleTest.cpp +++ b/presto-native-execution/presto_cpp/main/operators/tests/ShuffleTest.cpp @@ -1338,7 +1338,11 @@ TEST_F(ShuffleTest, persistentShuffleBatch) { pool()); for (auto i = 0; i < numRows; ++i) { - writer->collect(partition, std::string_view{}, views[i]); + writer->collect( + partition, + testData.sortedShuffle ? std::string_view(values[i].data(), 8) + : std::string_view{}, + views[i]); } writer->noMoreData(true); @@ -1352,6 +1356,7 @@ TEST_F(ShuffleTest, persistentShuffleBatch) { readInfo.partitionIds, testData.sortedShuffle, pool()); + reader->initialize(); int numOutputCalls{0}; int numBatches{0}; @@ -1748,6 +1753,161 @@ TEST_F(ShuffleTest, partitionAndSerializeEndToEnd) { runPartitionAndSerializeSerdeTest(data, 2, {{"c2", "c0"}}); } +TEST_F(ShuffleTest, persistentShuffleSortedEndToEnd) { + const uint32_t numPartitions = 1; + const uint32_t partition = 0; + + struct TestConfig { + size_t maxBytesPerPartition; + size_t numRows; + uint64_t readMaxBytes; + size_t minDataSize; + size_t maxDataSize; + std::string debugString() const { + return fmt::format( + "maxBytesPerPartition:{}, rows:{}, readMax:{}, dataSize:{}-{}", + maxBytesPerPartition, + numRows, + readMaxBytes, + minDataSize, + maxDataSize); + } + } testSettings[] = { + {.maxBytesPerPartition = 1024, + .numRows = 1, + .readMaxBytes = 1024, + .minDataSize = 10, + .maxDataSize = 50}, + {.maxBytesPerPartition = 1024, + .numRows = 10, + .readMaxBytes = 1024 * 1024, + .minDataSize = 50, + .maxDataSize = 200}, + {.maxBytesPerPartition = 500, + .numRows = 20, + .readMaxBytes = 1024 * 1024, + .minDataSize = 50, + .maxDataSize = 150}, + {.maxBytesPerPartition = 1024, + .numRows = 50, + .readMaxBytes = 8192, + .minDataSize = 100, + .maxDataSize = 400}, + {.maxBytesPerPartition = 2048, + .numRows = 100, + .readMaxBytes = 1024 * 1024, + .minDataSize = 200, + .maxDataSize = 1000}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto tempRootDir = velox::exec::test::TempDirectoryPath::create(); + const auto testRootPath = tempRootDir->getPath(); + + LocalShuffleWriteInfo writeInfo = LocalShuffleWriteInfo::deserialize( + localShuffleWriteInfo(testRootPath, numPartitions)); + + auto writer = std::make_shared( + writeInfo.rootPath, + writeInfo.queryId, + writeInfo.shuffleId, + writeInfo.numPartitions, + testData.maxBytesPerPartition, + /*sortedShuffle=*/true, + pool()); + + folly::Random::DefaultGenerator rng; + rng.seed(1); + std::vector randomKeys; + randomKeys.reserve(testData.numRows); + std::vector dataValues; + dataValues.reserve(testData.numRows); + + for (size_t i = 0; i < testData.numRows; ++i) { + randomKeys.push_back(static_cast(folly::Random::rand32(rng))); + + const size_t sizeRange = testData.maxDataSize - testData.minDataSize; + const size_t dataSize = testData.minDataSize + + (sizeRange > 0 ? folly::Random::rand32(rng) % sizeRange : 0); + + // Create data with index marker at the end for verification + std::string data(dataSize, static_cast('a' + (i % 26))); + data.append(fmt::format("_idx{:04d}", i)); + dataValues.push_back(std::move(data)); + } + for (size_t i = 0; i < randomKeys.size(); ++i) { + int32_t keyBigEndian = folly::Endian::big(randomKeys[i]); + std::string_view keyBytes( + reinterpret_cast(&keyBigEndian), kUint32Size); + writer->collect(partition, keyBytes, dataValues[i]); + } + writer->noMoreData(true); + + LocalShuffleReadInfo readInfo = LocalShuffleReadInfo::deserialize( + localShuffleReadInfo(testRootPath, numPartitions, partition)); + + auto reader = std::make_shared( + readInfo.rootPath, + readInfo.queryId, + readInfo.partitionIds, + /*sortedShuffle=*/true, + pool()); + reader->initialize(); + + size_t count = 0; + std::vector readDataValues; + + while (true) { + auto batches = reader->next(testData.readMaxBytes) + .via(folly::getGlobalCPUExecutor()) + .get(); + if (batches.empty()) { + break; + } + + for (const auto& batch : batches) { + for (const auto& row : batch->rows) { + readDataValues.emplace_back(row.data(), row.size()); + ++count; + } + } + } + + EXPECT_EQ(randomKeys.size(), count); + + // Get the sorted order of original keys using getSortOrder + std::vector keys; + keys.reserve(randomKeys.size()); + for (const auto& key : randomKeys) { + int32_t keyBigEndian = folly::Endian::big(key); + keys.emplace_back( + reinterpret_cast(&keyBigEndian), sizeof(int32_t)); + } + auto sortedOrder = getSortOrder(keys); + + // Verify data appears in sorted key order + for (size_t i = 0; i < readDataValues.size(); ++i) { + // Extract original index from data value (format: [chars]_idx0000) + const std::string& dataValue = readDataValues[i]; + size_t idxPos = dataValue.find("_idx"); + ASSERT_NE(idxPos, std::string::npos) + << "Data value at position " << i << " missing '_idx' marker: '" + << dataValue << "'"; + + size_t originalIdx = std::stoul(dataValue.substr(idxPos + 4)); + + // The data at position i should correspond to the key at sortedOrder[i] + EXPECT_EQ(originalIdx, sortedOrder[i]) + << "Data at position " << i << " should correspond to key at index " + << sortedOrder[i] << " but corresponds to index " << originalIdx; + } + reader->noMoreData(true); + cleanupDirectory(testRootPath); + } +} + } // namespace facebook::presto::operators::test int main(int argc, char** argv) { diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleInfoTranslator.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleInfoTranslator.java index 0604707c353e3..8746832a2aa54 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleInfoTranslator.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleInfoTranslator.java @@ -50,16 +50,22 @@ public PrestoSparkLocalShuffleInfoTranslator( @Override public PrestoSparkLocalShuffleWriteInfo createShuffleWriteInfo(Session session, PrestoSparkShuffleWriteDescriptor writeDescriptor) { - return new PrestoSparkLocalShuffleWriteInfo(writeDescriptor.getNumPartitions(), session.getQueryId().getId(), writeDescriptor.getShuffleHandle().shuffleId(), localShuffleRootPath); + // TODO: Determine sortedShuffle from PartitionAndSerializeNode plan metadata (sortingOrders/sortingKeys). + // Requires extending PrestoSparkShuffleWriteDescriptor or maintaining shuffleId->sortedShuffle mapping. + boolean sortedShuffle = false; + return new PrestoSparkLocalShuffleWriteInfo(writeDescriptor.getNumPartitions(), session.getQueryId().getId(), writeDescriptor.getShuffleHandle().shuffleId(), localShuffleRootPath, sortedShuffle); } @Override public PrestoSparkLocalShuffleReadInfo createShuffleReadInfo(Session session, PrestoSparkShuffleReadDescriptor readDescriptor) { + // TODO: Determine sortedShuffle from write-side metadata or shuffle descriptor. + boolean sortedShuffle = false; return new PrestoSparkLocalShuffleReadInfo( session.getQueryId().getId(), readDescriptor.getPartitionIds(), - localShuffleRootPath); + localShuffleRootPath, + sortedShuffle); } @Override diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleReadInfo.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleReadInfo.java index ebd3a4b8e08a0..37940e84946ba 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleReadInfo.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleReadInfo.java @@ -36,16 +36,19 @@ public class PrestoSparkLocalShuffleReadInfo // Partition ids which are supposed to be read by given shuffle read. private final List partitionIds; private final String rootPath; + private final boolean sortedShuffle; @JsonCreator public PrestoSparkLocalShuffleReadInfo( @JsonProperty("queryId") String queryId, @JsonProperty("partitionIds") List partitionIds, - @JsonProperty("rootPath") String rootPath) + @JsonProperty("rootPath") String rootPath, + @JsonProperty("sortedShuffle") boolean sortedShuffle) { this.queryId = requireNonNull(queryId, "queryId is null"); this.partitionIds = requireNonNull(partitionIds, "partitionIds is null"); this.rootPath = requireNonNull(rootPath, "rootPath is null"); + this.sortedShuffle = sortedShuffle; } @JsonProperty @@ -66,6 +69,12 @@ public List getPartitionIds() return partitionIds; } + @JsonProperty + public boolean getSortedShuffle() + { + return sortedShuffle; + } + @Override public String toString() { @@ -73,6 +82,7 @@ public String toString() .add("queryId", queryId) .add("partitionIds", String.join(", ", partitionIds)) .add("rootPath", rootPath) + .add("sortedShuffle", sortedShuffle) .toString(); } } diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleWriteInfo.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleWriteInfo.java index 2d8f1bda51683..792e67aff38f5 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleWriteInfo.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/shuffle/PrestoSparkLocalShuffleWriteInfo.java @@ -35,18 +35,21 @@ public class PrestoSparkLocalShuffleWriteInfo // Unique identifier for each shuffle stage generated by Spark private final int shuffleId; private final String rootPath; + private final boolean sortedShuffle; @JsonCreator public PrestoSparkLocalShuffleWriteInfo( @JsonProperty("numPartitions") int numPartitions, @JsonProperty("queryId") String queryId, @JsonProperty("shuffleId") int shuffleId, - @JsonProperty("rootPath") String rootPath) + @JsonProperty("rootPath") String rootPath, + @JsonProperty("sortedShuffle") boolean sortedShuffle) { this.numPartitions = numPartitions; this.queryId = requireNonNull(queryId, "queryId is null"); this.shuffleId = shuffleId; this.rootPath = requireNonNull(rootPath, "rootPath is null"); + this.sortedShuffle = sortedShuffle; } @JsonProperty @@ -73,6 +76,12 @@ public int getShuffleId() return shuffleId; } + @JsonProperty + public boolean getSortedShuffle() + { + return sortedShuffle; + } + @Override public String toString() { @@ -81,6 +90,7 @@ public String toString() .add("queryId", queryId) .add("shuffleId", shuffleId) .add("rootPath", rootPath) + .add("sortedShuffle", sortedShuffle) .toString(); } } diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/TestBatchTaskUpdateRequest.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/TestBatchTaskUpdateRequest.java index c22db6a17b78d..6c8370d6a3e52 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/TestBatchTaskUpdateRequest.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/TestBatchTaskUpdateRequest.java @@ -78,7 +78,7 @@ public void testJsonConversion() PrestoSparkLocalShuffleInfoTranslator shuffleInfoTranslator = new PrestoSparkLocalShuffleInfoTranslator( PRESTO_SPARK_LOCAL_SHUFFLE_READ_INFO_JSON_CODEC, PRESTO_SPARK_LOCAL_SHUFFLE_WRITE_INFO_JSON_CODEC); - PrestoSparkLocalShuffleReadInfo readInfo = new PrestoSparkLocalShuffleReadInfo("test_query_id", ImmutableList.of("shuffle1"), "/dummy/read/path"); + PrestoSparkLocalShuffleReadInfo readInfo = new PrestoSparkLocalShuffleReadInfo("test_query_id", ImmutableList.of("shuffle1"), "/dummy/read/path", false); String stringSerializedReadInfo = shuffleInfoTranslator.createSerializedReadInfo(readInfo); PlanNodeId planNodeId = new PlanNodeId("planNodeId"); @@ -127,8 +127,8 @@ public void testShuffleInfoSerialization() PrestoSparkLocalShuffleInfoTranslator shuffleTranslator = new PrestoSparkLocalShuffleInfoTranslator( PRESTO_SPARK_LOCAL_SHUFFLE_READ_INFO_JSON_CODEC, PRESTO_SPARK_LOCAL_SHUFFLE_WRITE_INFO_JSON_CODEC); - PrestoSparkLocalShuffleReadInfo readInfo = new PrestoSparkLocalShuffleReadInfo("test_query_id", ImmutableList.of("shuffle1"), "/dummy/read/path"); - PrestoSparkLocalShuffleWriteInfo writeInfo = new PrestoSparkLocalShuffleWriteInfo(1, "test_query_id", 0, "/dummy/write/path"); + PrestoSparkLocalShuffleReadInfo readInfo = new PrestoSparkLocalShuffleReadInfo("test_query_id", ImmutableList.of("shuffle1"), "/dummy/read/path", false); + PrestoSparkLocalShuffleWriteInfo writeInfo = new PrestoSparkLocalShuffleWriteInfo(1, "test_query_id", 0, "/dummy/write/path", false); String stringSerializedReadInfo = shuffleTranslator.createSerializedReadInfo(readInfo); String stringSerializedWriteInfo = shuffleTranslator.createSerializedWriteInfo(writeInfo); assertEquals( @@ -136,7 +136,8 @@ public void testShuffleInfoSerialization() "{\n" + " \"queryId\" : \"test_query_id\",\n" + " \"partitionIds\" : [ \"shuffle1\" ],\n" + - " \"rootPath\" : \"/dummy/read/path\"\n" + + " \"rootPath\" : \"/dummy/read/path\",\n" + + " \"sortedShuffle\" : false\n" + "}"); assertEquals( stringSerializedWriteInfo, @@ -144,7 +145,8 @@ public void testShuffleInfoSerialization() " \"numPartitions\" : 1,\n" + " \"queryId\" : \"test_query_id\",\n" + " \"shuffleId\" : 0,\n" + - " \"rootPath\" : \"/dummy/write/path\"\n" + + " \"rootPath\" : \"/dummy/write/path\",\n" + + " \"sortedShuffle\" : false\n" + "}"); }