reduce NDS benchmark e2e time by optimizing RapidsShuffleManager #13479
Pull Request Overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| "RAPIDS shuffle time stalled by input stream operations" | ||
|
|
||
| // ThreadLocal variable to track if current task has only one ColumnarBatch | ||
| private val singleBatchTaskLocal: ThreadLocal[Boolean] = new ThreadLocal[Boolean] { |
Let's not keep thread locals for this state.
Can we at least use GpuColumnVector.tagAsFinalBatch? I know tagAsFinalBatch is right now only applied in the sort and the coalesce. If the shuffle writer sees the first batch, and it's also final, then we know it's a single-batch task.
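
A minimal sketch of this suggestion, for illustration only (not the PR's code; `isTaggedAsFinalBatch` is a hypothetical accessor for the tag that `tagAsFinalBatch` sets):

```scala
import org.apache.spark.sql.vectorized.ColumnarBatch

// Sketch: if the first batch the writer sees already carries the "final" tag,
// the task is single-batch. The accessor is passed in because the exact API
// for reading the tag is an assumption here.
class SingleBatchDetector(isTaggedAsFinalBatch: ColumnarBatch => Boolean) {
  private var seenFirst = false
  var isSingleBatchTask = false

  def observe(batch: ColumnarBatch): Unit = {
    if (!seenFirst) {
      seenFirst = true
      isSingleBatchTask = isTaggedAsFinalBatch(batch)
    }
  }
}
```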
OK, any particular reason why a thread local is not favored?
Because now the shuffle writer has state in the GpuShuffleExchangeExecBase that it needs to query. My point is not about the thread local per se, but more about scribbling away state in different places for something that could be communicated by the batch itself. We are crossing interface boundaries, arguably unnecessarily.
I added a ColumnVectorWithState, because tagAsFinalBatch cannot be used directly in our case.
```scala
if (isFirstCall) {
  isFirstCall = false

  if (!iter.hasNext) {
```
We have to ensure that hasNext will not cause the whole plan to execute, in every circumstance.
I think the biggest problem may lie in the Scan part; typically we'll do the real reading in hasNext(). But I do see some similar patterns where we do hasNext to check for the single-batch case:

Does this look like a problem to you? @abellina
> I think the biggest problem may lie in the Scan part; typically we'll do the real reading in hasNext().

Yes, this is the type of issue I was expecting to find. Though in the example you provided, the function returns an iterator (CachedGpuBatchIterator), which seems inert in its hasNext (it just checks pending.nonEmpty, which was needed to instantiate the iterator object).

What I meant by checking every case is that it has to be exhaustive, and that can be pretty hard to get right. Sometimes we extend Iterator, sometimes we just do it inline: new Iterator[T] { ... }. I do not know if there are clever ways of finding all the issues here.
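
For concreteness, a sketch of the two hasNext shapes being discussed (illustrative only, not code from this PR):

```scala
import scala.collection.mutable

// An "inert" hasNext only inspects already-materialized state, so probing it
// is safe (this mirrors the CachedGpuBatchIterator case described above):
class InertIterator[T](pending: mutable.Queue[T]) extends Iterator[T] {
  override def hasNext: Boolean = pending.nonEmpty // cheap, no side effects
  override def next(): T = pending.dequeue()
}

// An "eager" hasNext performs real work (e.g. a scan reading its next file),
// which is exactly what the shuffle writer must not force by accident:
class EagerIterator[T](readNext: () => Option[T]) extends Iterator[T] {
  private var buffered: Option[T] = None
  override def hasNext: Boolean = {
    if (buffered.isEmpty) buffered = readNext() // may do expensive I/O here
    buffered.nonEmpty
  }
  override def next(): T = { // callers must check hasNext first
    val result = buffered.get
    buffered = None
    result
  }
}
```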
An 8% win will be amazing. I had a couple of comments around both .hasNext and the limiter release piece. I am curious to know how much of the 8% is because we queued extra work given the .hasNext, and how much of it is due to the limiter being released too early (before we close the buffer).
```scala
// Track compression futures per partition
val partitionFutures =
  new ConcurrentHashMap[Int, java.util.concurrent.ConcurrentLinkedQueue[Future[Long]]]()
```
For a single-batch task, shouldn't this just be a map of reducer id -> Future[Long]? How many futures do you expect?
This is for the case where a sub-partition is further split by spark.rapids.shuffle.partitioning.maxCpuBatchSize.
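
For illustration, a sketch of the resulting shape (a hypothetical standalone version, simplified from the PR):

```scala
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Future}

// One reducer partition can yield several compression tasks when a batch is
// further split by spark.rapids.shuffle.partitioning.maxCpuBatchSize, so each
// reducer id maps to a queue of futures rather than a single future.
class PartitionFutures {
  private val futures =
    new ConcurrentHashMap[Int, ConcurrentLinkedQueue[Future[Long]]]()

  def add(reducePartitionId: Int, f: Future[Long]): Unit =
    futures
      .computeIfAbsent(reducePartitionId, _ => new ConcurrentLinkedQueue[Future[Long]]())
      .add(f)
}
```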
Our spark2a cluster is not available today. I have to wait until tomorrow to collect the latest perf numbers on NDS.
Hi @abellina and @revans2, I fixed the CI errors and now the PR is ready to review. My final regression test (https://gist.github.com/binmahone/3af72e5a9377e0e292f03ca477bd761c) shows a 7.3% overall improvement, with some queries improving as much as 20~30%. There is one regression query (query93), but after checking the event log it seems to be related to Scan jitter. Will you please go through the PR again, and if possible also try running the regression test (with your preferred configs) to see if the improvement is reproducible? The flame graph after this PR is applied:

As we can see, the shuffle write part still accounts for a portion (bottom right corner), but not that big a portion any more.
I tried another two rounds (in each round, the test version and the baseline version are both run 3~4 times), and got 9.6% and 9.2% improvement respectively. So overall I would say this PR brings a 7.3%~9.6% improvement to NDS.
revans2 left a comment:
In general this looks good to me. My main concern is what happens to the performance in the non-happy case.

To figure this out I ran a simple query with 4 different shuffle implementations at 4 different max partition bytes settings:
```scala
// NOTE --conf spark.shuffle.sort.bypassMergeThreshold=199 for sort-shuffle
spark.conf.set("spark.sql.shuffle.partitions", 200)
//spark.conf.set("spark.sql.files.maxPartitionBytes", "512m")
//spark.conf.set("spark.sql.files.maxPartitionBytes", "1024m")
//spark.conf.set("spark.sql.files.maxPartitionBytes", "2048m")
spark.conf.set("spark.sql.files.maxPartitionBytes", "4096m")

(0 until 5).foreach { _ =>
  spark.time(
    spark.read.parquet("/data/tpcds/SF1000_tmp_parquet/store_sales/")
      .selectExpr("ss_sold_date_sk", "ss_sold_time_sk", "ss_item_sk", "ss_customer_sk",
        "ss_cdemo_sk", "ss_hdemo_sk", "ss_addr_sk")
      .repartition(col("ss_sold_date_sk"), col("ss_sold_time_sk"))
      .selectExpr("COUNT(ss_sold_date_sk)", "COUNT(ss_sold_time_sk)", "COUNT(ss_item_sk)",
        "COUNT(ss_customer_sk)", "COUNT(ss_cdemo_sk)", "COUNT(ss_hdemo_sk)",
        "COUNT(ss_addr_sk)")
      .show())
}
```
- At 512m this patch was about 5% faster than the regular multi-threaded shuffle (but with low confidence); this is one shuffle batch per task.
- At 1024m this patch was only 3% faster than the regular multi-threaded shuffle (again with low confidence); this is still one shuffle batch per task.
- At 2048m this patch was 3% slower than the regular multi-threaded shuffle (yet again with low confidence); this one was about 2.8 batches per task.
- At 4096m this patch was 14% slower than MT shuffle, 15% slower than sort bypass shuffle, and 13% slower than the sort-based shuffle (all with high confidence); this one was about 4.8 batches per task.
It looks like this provides a very good speedup for the happy path, but when you check for the next batch and it needs to materialize the whole thing, it makes the memory pressure high enough that we spill more and end up running with less parallelism.

I am not really sure how to mitigate this, as I am not 100% sure yet what is happening. We might be able to copy the data to the host before we check to see if there is a second batch; this should reduce the memory pressure. We might be able to always write the first batch out to disk, and then take the hit to copy it if we guessed wrong, but to make that work we need to ensure that we don't need to decompress the data before writing it out again.
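
A rough sketch of the first mitigation, with hypothetical helper names (`HostBatch`, `copyToHost`), just to make the ordering concrete:

```scala
import org.apache.spark.sql.vectorized.ColumnarBatch

// Materialize the first batch on the host *before* probing the iterator for a
// second batch, so the probe does not pile new GPU work on top of a
// still-resident device batch. copyToHost is assumed to free device memory.
def writeFirstBatchThenProbe[HostBatch](
    iter: Iterator[ColumnarBatch],
    copyToHost: ColumnarBatch => HostBatch,
    write: HostBatch => Unit): Boolean = {
  val hostCopy = copyToHost(iter.next()) // device memory released here
  val singleBatch = !iter.hasNext        // probe may now force upstream work
  write(hostCopy)
  singleBatch // true => the single-batch fast path is valid
}
```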
```scala
  _ => new OpenByteArrayOutputStream())
val originLength = buffer.getCount

// Serialize + compress to memory buffer
```
nit: this could also be doing encryption. Very minor.
Hi @revans2, can you please share your steps? I tried to reproduce (detailed steps at https://gist.github.com/binmahone/ba8597843af0bf50cf7e5ec6013c947b) but here's what I got:

As you can see, I'm comparing two jars: the baseline jar and the PR jar; each app you see in the Spark history page runs the same query + same config 5 times. For the max partition bytes = 4g case, the PR jar performs similarly to the baseline jar. For the max partition bytes = 512m case (i.e. the single-batch case), the PR jar performs better than the baseline jar when the RapidsShuffleManager is enabled, but performs similarly to the baseline jar with the RapidsShuffleManager disabled. This is as expected, because the current PR is essentially an enhancement to the RapidsShuffleManager's write routine. It's also worth mentioning that I collected the above results on @thirtiseven's workstation. At first I was collecting on my own workstation, but even with the baseline jar the duration varied significantly across the 5 runs of the same query and same config. One possible reason is that my disk is already >80% utilized, and when an SSD is near full its performance becomes unstable. Anyway, tomorrow I'll plug another disk into my workstation and try again.
Greptile Overview
Greptile Summary
This PR optimizes shuffle write performance for the common single-batch scenario by eliminating intermediate temporary files and writing partitions directly to memory buffers, then streaming them to disk. The optimization achieves an 8% reduction in NDS end-to-end time.
Key changes:
- Refactored `RapidsShuffleThreadedWriterBase` to use a new streaming architecture where serialization/compression happens into in-memory buffers (`OpenByteArrayOutputStream`) followed by sequential writes to the final output file
- Single-batch optimization: when a task produces only one GPU batch (the common case), partitions are written directly without temporary files
- Multi-batch support: when multiple batches are detected (partition IDs decrease), creates separate partial files and merges them at the end
- Introduced `ColumnVectorWithState` base class to track batch finality across all column vector types
- Added pipeline processing with dedicated writer threads per batch for parallelism
Issues found:
- Critical logic bug in writer thread termination condition (line 378-379) that could cause the writer to attempt processing a non-existent partition
Confidence Score: 2/5
- This PR contains a critical loop termination bug that could cause runtime failures
- The optimization logic is sound and well-documented, but there's a critical bug in the writer thread's loop termination condition (lines 378-379). The condition uses `!=` when it should use `<`, which could cause the writer to attempt processing partition `numPartitions` (which doesn't exist, since partitions are 0-indexed from 0 to numPartitions-1). This would likely cause an IndexOutOfBoundsException or similar error. Additionally, the use of reflection to extract internal fields from `ShuffleMapOutputWriter` is fragile and could break with Spark version changes.
- sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala requires immediate attention to fix the loop termination bug
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala | 2/5 | Major refactoring of shuffle writer with single-batch optimization and multi-batch pipeline support; contains critical loop termination bug |
| sql-plugin/src/main/java/com/nvidia/spark/rapids/ColumnVectorWithState.java | 5/5 | New base class extracting batch state tracking from GpuColumnVectorBase; straightforward refactoring |
| sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsShuffleWriter.scala | 5/5 | Added mapOutputWriters tracking for proper cleanup on error; minimal change |
Sequence Diagram
```mermaid
sequenceDiagram
participant MT as Main Thread
participant Rec as Records Iterator
participant Lim as BytesInFlightLimiter
participant WT as Writer Thread Pool
participant WTh as Batch Writer Thread
participant Buf as Memory Buffers
participant Disk as Output File
Note over MT,Disk: Single Batch Scenario (Optimized Path)
MT->>Rec: hasNext()
Rec-->>MT: true
MT->>Rec: next() → (key, value, partitionId)
MT->>MT: incRefCountAndGetSize(value)
MT->>Lim: acquireOrBlock(recordSize)
Note over Lim: Blocks if maxBytesInFlight exceeded
Lim-->>MT: acquired
MT->>WT: queueWriteTask(slotNum, serializeTask)
WT->>WT: serialize + compress to memory buffer
WT->>Buf: write compressed data
WT-->>MT: Future[(recordSize, compressedSize)]
MT->>WTh: notify writer thread (writerCondition)
par Main Thread continues
MT->>Rec: next() → next partition data
and Writer Thread processes
WTh->>WTh: wait for futures to complete
WTh->>WT: future.get() → sizes
WTh->>Disk: write buffer[partition] to disk
WTh->>Lim: release(recordSize)
end
Note over MT: All records processed
MT->>MT: maxPartitionSeen.set(numPartitions)
MT->>MT: processingComplete.set(true)
MT->>WTh: notify (writerCondition)
WTh->>WTh: finish remaining partitions
WTh->>Disk: write final buffers
WTh-->>MT: complete
MT->>Disk: commitAllPartitions()
Note over MT,Disk: Multi-Batch Scenario
MT->>Rec: detect partitionId < previousMax
MT->>MT: isMultiBatch = true
MT->>WTh: signal batch1 complete
MT->>MT: create new BatchState (batch2)
MT->>MT: spawn new Writer Thread
par Batch1 finalization
WTh->>Disk: write to partial file 1
and Batch2 processing
MT->>Rec: continue with batch2 records
end
Note over MT: All batches complete
MT->>MT: wait for all writer threads
MT->>Disk: mergePartialFiles()
MT->>Disk: commitAllPartitions(merged)
```
10 files reviewed, 1 comment
```scala
while (!processingComplete.get() ||
    currentPartitionToWrite != maxPartitionSeen.get()) {
```
logic: the loop termination condition has a potential issue.

When `processingComplete` is true and `maxPartitionSeen.get()` equals `numPartitions`, the condition `currentPartitionToWrite != maxPartitionSeen.get()` becomes `currentPartitionToWrite != numPartitions`, which means the loop will terminate when `currentPartitionToWrite == numPartitions`.

However, partition IDs are 0-indexed, so valid partitions are 0 to numPartitions-1. The writer should process up to and including partition numPartitions-1, then exit when `currentPartitionToWrite == numPartitions`.

But lines 532 and 614 set `maxPartitionSeen.set(numPartitions)`, which means the writer will try to write partition `numPartitions` (which doesn't exist) before terminating.
```diff
- while (!processingComplete.get() ||
-     currentPartitionToWrite != maxPartitionSeen.get()) {
+ while (!processingComplete.get() ||
+     currentPartitionToWrite < maxPartitionSeen.get()) {
```
@binmahone let's keep the > 200 case separate from this PR. They are different concerns.
It's just that I changed the code so that it addresses the "hasNext overhead issue" and the ">200 partitions issue" at the same time, and it has become difficult to separate them. My hope is that we can merge this PR with minimal scope, but with the original commits I could not address the "hasNext overhead issue" very well. There are many other subtasks in #13673; for those I'll use new PRs.
Greptile Overview
Greptile Summary
Optimizes shuffle write performance by eliminating intermediate temp file I/O for sorted partition data. Instead of writing each partition to separate files then merging, the code now serializes/compresses directly to memory buffers and streams to final output, achieving ~8% NDS benchmark improvement.
Key changes:
- Removed the sequential temp-file merge bottleneck by introducing memory buffering with `OpenByteArrayOutputStream`
- Added a pipeline architecture for multi-batch scenarios where each batch gets an independent writer thread
- Implemented partition-ordered streaming where the writer thread consumes compressed buffers as compression tasks complete
- Uses reflection to extract `outputTempFile` and `partitionLengths` from `LocalDiskShuffleMapOutputWriter` internals for multi-batch merging
- Tracks multiple `ShuffleMapOutputWriter` instances in the `mapOutputWriters` collection for proper cleanup
The optimization exploits the fact that GPU batches produce pre-sorted partition data (partition IDs appear in order like {0,0,1,2,3,3,4...}), allowing direct memory buffering within BytesInFlightLimiter constraints.
Confidence Score: 4/5
- This PR is reasonably safe to merge with careful testing, though the complexity and use of reflection warrant thorough validation
- The implementation is a significant architectural change that improves performance (~8% in NDS benchmarks) but introduces substantial complexity. The code uses reflection to access Spark internals, has intricate multi-threaded coordination logic, and handles multiple execution paths (single vs multi-batch). The previous comment identified a potential loop termination issue that needs verification. The change has been performance-tested but the correctness of edge cases (empty partitions, error handling, resource cleanup) needs careful validation.
- Pay special attention to the writer thread loop termination logic (lines 378-437), multi-batch detection logic (lines 527-554), and reflection-based file extraction (lines 826-859)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala | 4/5 | Major refactor of shuffle writer to optimize disk I/O by writing sorted partition data directly to memory buffers, then streaming to disk. Adds multi-batch pipeline processing with independent writer threads per batch. Uses reflection to access internal Spark fields. |
Sequence Diagram
```mermaid
sequenceDiagram
participant Main as Main Thread
participant Limiter as BytesInFlightLimiter
participant CompPool as Compression Thread Pool
participant Writer as Writer Thread
participant Disk as Disk I/O
Note over Main: Single Batch Case
Main->>Main: detect partition sorted data
Main->>Writer: create BatchState & writer thread
loop For each record
Main->>Main: getPartition(key)
Main->>Limiter: acquireOrBlock(recordSize)
Limiter-->>Main: acquired
Main->>CompPool: queue compression task (slotNum)
Main->>Writer: notify new data available
CompPool->>CompPool: serialize + compress to memory buffer
CompPool->>CompPool: return (recordSize, compressedSize)
Writer->>Writer: wait for partition complete
Writer->>Writer: get futures for partition
Writer->>Writer: future.get() blocks until compression done
Writer->>Disk: write buffer to partition file
Writer->>Limiter: release(recordSize)
end
Main->>Writer: set processingComplete = true
Main->>Writer: set maxPartitionSeen = numPartitions
Main->>Writer: notifyAll()
Writer->>Disk: flush remaining partitions
Writer-->>Main: complete
Note over Main: Multi-Batch Case
Main->>Main: detect partitionId < previousMax
Main->>Main: signal current batch complete
Main->>Writer: set maxPartitionSeen = numPartitions
Main->>Main: create NEW BatchState (pipeline!)
Main->>Main: continue processing next batch
Note over Main,Disk: Batches process in parallel
Main->>Main: wait for all batch writers
Main->>Main: merge partial files (partition by partition)
Main->>Disk: write final merged output
1 file reviewed, 1 comment
```scala
var currentBatch = createBatchState(currentBatchId, mapOutputWriter)

try {
  while (records.hasNext) {
    val record = records.next()
    val key = record._1
    val value = record._2
    val reducePartitionId: Int = partitioner.getPartition(key)

    // Detect multi-batch: partition ID decreased means new batch started
    if (reducePartitionId < previousMaxPartition) {
      if (!isMultiBatch) {
        isMultiBatch = true
        logInfo(s"Detected multi-batch scenario for shuffle $shuffleId, " +
          s"transitioning to pipeline mode")
      }
} finally {
  // This is in a finally block so that if there is an exception queueing
  // futures, that we will have waited for any queued write future before we call
  // .abort on the map output writer (we had test failures otherwise)
  NvtxRegistry.WAITING_FOR_WRITES {
    try {
      while (writeFutures.nonEmpty) {
        try {
          writeFutures.dequeue().get()
        } catch {
          case ee: ExecutionException =>
            // this exception is a wrapper for the underlying exception
            // i.e. `IOException`. The ShuffleWriter.write interface says
            // it can throw these.
            throw ee.getCause
        }

      // Signal current batch is complete (but don't block next batch!)
      currentBatch.maxPartitionSeen.set(numPartitions)
      currentBatch.processingComplete.set(true)
      currentBatch.writerCondition.synchronized {
        currentBatch.writerCondition.notifyAll()
      }

      // Add to list for later finalization
      batchStates += currentBatch

      // Immediately create new batch and continue processing (pipeline!)
      currentBatchId += 1
      val newWriter = shuffleExecutorComponents.createMapOutputWriter(
        shuffleId,
        mapId,
        numPartitions)
      mapOutputWriters += newWriter // Track for cleanup
      currentBatch = createBatchState(currentBatchId, newWriter)

      previousMaxPartition = -1
    }

    recordsWritten += 1
    previousMaxPartition = math.max(previousMaxPartition, reducePartitionId)

    // Get or create futures queue for this partition in current batch
    val futures = currentBatch.partitionFutures.computeIfAbsent(reducePartitionId,
      _ => new java.util.concurrent.CopyOnWriteArrayList[Future[(Long, Long)]]())

    val (cb, recordSize) = incRefCountAndGetSize(value)

    // Acquire limiter and process compression task immediately
    val waitOnLimiterStart = System.nanoTime()
    limiter.acquireOrBlock(recordSize)
    waitTimeOnLimiterNs += System.nanoTime() - waitOnLimiterStart

    // Get or assign a slot number for this partition to ensure
    // all tasks for the same partition run serially in the same slot
    val slotNum = partitionSlots.computeIfAbsent(reducePartitionId,
      _ => RapidsShuffleInternalManagerBase.getNextWriterSlot)
    val finalCurrentBatch = currentBatch
    val future = RapidsShuffleInternalManagerBase.queueWriteTask(slotNum, () => {
      try {
        withResource(cb) { _ =>
          // Get or create buffer for this partition in current batch
          val buffer = finalCurrentBatch.partitionBuffers.computeIfAbsent(
            reducePartitionId, _ => new OpenByteArrayOutputStream())
          val originLength = buffer.getCount

          // Serialize + compress to memory buffer
          val compressedOutputStream = blockManager.serializerManager.wrapStream(
            ShuffleBlockId(shuffleId, mapId, reducePartitionId), buffer)

          val serializationStream = serializerInstance.serializeStream(
            compressedOutputStream)
          withResource(serializationStream) { serializer =>
            serializer.writeKey(key.asInstanceOf[Any])
            serializer.writeValue(value.asInstanceOf[Any])
          }
      } finally {
        // cancel all pending futures (only in case of error will we cancel)
        writeFutures.foreach(_.cancel(true /*ok to interrupt*/))

          // Track total written data size (compressed size)
          val compressedSize = (buffer.getCount - originLength).toLong
          totalCompressedSize.addAndGet(compressedSize)
          (recordSize, compressedSize)
        }
      } catch {
        case e: Exception => {
          logError(s"Exception in compression task for shuffle $shuffleId", e)
          throw e
        }
      }
    })

    currentBatch.maxPartitionSeen.synchronized {
      futures.add(future)
      currentBatch.maxPartitionSeen.set(
        math.max(currentBatch.maxPartitionSeen.get(), reducePartitionId))
    }

    // writeTimeNs is an approximation of the amount of time we spent in
    // DiskBlockObjectWriter.write, which involves serializing records and writing them
    // on disk. As we use multiple threads for writing, writeTimeNs is
    // estimated by 'the total amount of time it took to finish processing the entire logic
    // above' minus 'the amount of time it took to do anything expensive other than the
    // serialization and the write'. The latter involves computations in upstream execs,
    // ColumnarBatch size estimation, and the time blocked on the limiter.
    val writeTimeNs = (System.nanoTime() - processingStart) -
      records.getIterateTimeNs - batchSizeComputeTimeNs - waitTimeOnLimiterNs

    val combineTimeStart = System.nanoTime()
    val pl = writePartitionedData(mapOutputWriter)
    val combineTimeNs = System.nanoTime() - combineTimeStart

    // add openTime which is also done by Spark, and we are counting
    // in the ioTime later
    writeMetrics.incWriteTime(openTimeNs)

    // At this point, Spark has timed the amount of time it took to write
    // to disk (the IO, per write). But note that when we look at the
    // multi threaded case, this metric is now no longer task-time.
    // Users need to look at "rs. shuffle write time" (shuffleWriteTimeMetric),
    // which does its own calculation at the task-thread level.
    // We use ioTimeNs, however, to get an approximation of serialization time.
    val ioTimeNs =
      writeMetrics.asInstanceOf[ThreadSafeShuffleWriteMetricsReporter].getWriteTime

    // serializationTime is the time spent compressing/encoding batches that wasn't
    // counted in the ioTime
    val totalPerRecordWriteTime = recordWriteTime.get() + ioTimeNs
    val ioRatio = (ioTimeNs.toDouble / totalPerRecordWriteTime)
    val serializationRatio = 1.0 - ioRatio

    // update metrics, note that we expect them to be relative to the task
    ioTimeMetric.foreach(_ += (ioRatio * writeTimeNs).toLong)
    serializationTimeMetric.foreach(_ += (serializationRatio * writeTimeNs).toLong)
    // we add all three here because this metric is meant to show the time
    // we are blocked on writes
    shuffleWriteTimeMetric.foreach(_ += (writeTimeNs + combineTimeNs))
    shuffleCombineTimeMetric.foreach(_ += combineTimeNs)
    pl
    // Wake up writer thread for current batch
    currentBatch.writerCondition.synchronized {
      currentBatch.writerCondition.notifyAll()
    }
  }
  myMapStatus = Some(getMapStatus(blockManager.shuffleServerId, partLengths, mapId))
} catch {
  // taken directly from BypassMergeSortShuffleWriter
  case e: Exception =>

    // Mark end of last batch - ensure all partitions are processed
    currentBatch.maxPartitionSeen.set(numPartitions)
    currentBatch.processingComplete.set(true)
    currentBatch.writerCondition.synchronized {
      currentBatch.writerCondition.notifyAll()
    }

    // Add last batch to list
    batchStates += currentBatch
```
logic: if an exception occurs during the processing loop (lines 520-617), `currentBatch` won't have been added to `batchStates` yet (that happens at line 627), so the finally block won't clean it up.

The writer executor, open streams, and buffers in `currentBatch` could leak.

Consider adding `currentBatch` to `batchStates` immediately after creation at line 517, or ensure the cleanup in the finally block handles the in-progress batch separately.
Greptile Overview
Greptile Summary
This PR optimizes the RAPIDS shuffle manager by implementing streaming parallel processing with pipelined partition writing, achieving an 8% reduction in NDS end-to-end time.
Key changes:
- Single-batch optimization: For tasks producing one GPU batch, rows are already sorted by partition. The new implementation serializes/compresses partitions directly to memory buffers and writes them sequentially to the destination file, eliminating intermediate temporary files per partition.
- Multi-batch support: When multiple batches are detected (partition ID decreases), the system creates independent batch states that can process in pipeline, writing to separate temporary files that are later merged.
- Parallel compression: Multiple compression tasks can run concurrently in a thread pool (partitioned by partition ID to maintain ordering), while a dedicated writer thread for each batch writes completed partitions in order.
- Metrics update: Removed deprecated serialization and write time metrics, added new metrics for limiter wait time and serialization wait time.
The previous comment about loop termination logic at line 378 appears to be incorrect upon analysis. The loop correctly processes partitions 0 through numPartitions-1, using maxPartitionSeen.set(numPartitions) as a sentinel value to indicate completion. The condition currentPartitionToWrite < numPartitions ensures the loop exits after processing partition numPartitions-1.
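
A standalone illustration of the sentinel pattern described here (simplified: the real writer parks on `writerCondition` rather than spinning):

```scala
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

// maxPartitionSeen is bumped to numPartitions (one past the last valid id) as
// a completion sentinel, so the writer drains ids 0..numPartitions-1 and stops.
class SentinelDrain(numPartitions: Int) {
  val processingComplete = new AtomicBoolean(false)
  val maxPartitionSeen = new AtomicInteger(-1)

  def drain(writePartition: Int => Unit): Unit = {
    var current = 0
    while (!processingComplete.get() || current < maxPartitionSeen.get()) {
      if (current < maxPartitionSeen.get()) {
        writePartition(current) // only valid ids 0..numPartitions-1 reach here
        current += 1
      } // else: busy-wait stand-in for waiting on writerCondition
    }
  }
}
```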
However, the previous comment about resource cleanup at line 517 identifies a valid issue: if an exception occurs during record processing before currentBatch is added to batchStates (line 541 or 625), the batch's executor, futures, and buffers won't be cleaned up in the finally block.
Confidence Score: 3/5
- This PR has significant performance benefits but contains a resource leak vulnerability that could cause issues under failure scenarios
- Score reflects the substantial refactoring with proven performance gains (8% improvement), but the resource cleanup issue poses a real risk in production. If exceptions occur during batch processing, thread executors and memory buffers may not be properly released. The core shuffle logic appears sound, and the multi-batch detection/pipelining is well-designed, but the exception handling needs improvement before merge.
- Pay close attention to `RapidsShuffleInternalManagerBase.scala`: the resource cleanup logic in the finally block needs to handle the in-flight `currentBatch` that may not yet be added to `batchStates`
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala | 3/5 | Major refactoring of shuffle writer to use streaming parallel processing with pipelined partition writing. Resource cleanup issue exists if exception occurs during initial batch processing before batch is added to tracking list. |
| sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala | 5/5 | Removed deprecated shuffle metrics and added new threaded writer metrics for tracking limiter wait time and serialization wait time. |
Sequence Diagram
```mermaid
sequenceDiagram
participant MT as Main Thread
participant L as BytesInFlightLimiter
participant TP as Thread Pool (per partition slot)
participant WT as Writer Thread
participant FS as File System
Note over MT: Single Batch Scenario
MT->>MT: Create BatchState with writer thread
WT->>WT: Start writer thread (waits for data)
loop For each record in batch
MT->>MT: Get partition ID from partitioner
MT->>L: acquireOrBlock(recordSize)
L-->>MT: Acquired
MT->>TP: Queue compression task (by partition slot)
TP->>TP: Serialize + compress to memory buffer
TP-->>MT: Return Future[(recordSize, compressedSize)]
MT->>MT: Add future to partition's futures list
MT->>WT: Notify writer thread
WT->>WT: Check if current partition ready
WT->>TP: Wait for futures to complete
TP-->>WT: Compressed data ready
WT->>FS: Write compressed data to partition stream
WT->>L: release(recordSize)
end
MT->>MT: Signal completion (maxPartitionSeen = numPartitions)
MT->>WT: Notify writer thread
WT->>WT: Process remaining partitions
WT->>FS: Close all partition streams
WT-->>MT: Writer thread completes
MT->>FS: Commit all partitions
Note over MT: Multi-Batch Scenario
MT->>MT: Create BatchState 0
WT->>WT: Start writer thread 0
loop For each record in batch 0
MT->>L: acquireOrBlock(recordSize)
MT->>TP: Queue compression task
MT->>WT: Notify writer thread 0
WT->>FS: Write to temp file 0
end
MT->>MT: Detect partition ID decrease (new batch!)
MT->>MT: Signal batch 0 complete (maxPartitionSeen = numPartitions)
MT->>MT: Create BatchState 1 (pipeline!)
WT->>WT: Start writer thread 1
par Batch 0 writer completes
WT->>FS: Finish writing temp file 0
and Batch 1 processes in parallel
loop For each record in batch 1
MT->>L: acquireOrBlock(recordSize)
MT->>TP: Queue compression task
MT->>WT: Notify writer thread 1
WT->>FS: Write to temp file 1
end
end
MT->>MT: All batches complete
MT->>MT: Merge partial files
loop For each partition 0..N
loop For each partial file
MT->>FS: Copy partition P from partial file to final output
end
end
MT->>FS: Commit final merged output
MT->>MT: Cleanup batch states and temp files
```
4 files reviewed, no comments
NOTE: release/25.12 has been created from main. Please retarget your PR to release/25.12 if it should be included in the release.
revans2 left a comment:
I think I got through all of it this time.
```java
 * This class provides state tracking for batch processing, specifically
 * for tracking whether a batch is final or a sub-partition of the final batch.
 */
public abstract class ColumnVectorWithState extends ColumnVector {
```
This change feels a bit premature. I can see that this is extending the final-batch state tracking to include a few host column vectors now, but I don't see anywhere that the final batch tracking is updated to copy to the host buffers. I also don't see anywhere that isKnownFinalBatch is used by this patch. Could we remove it, just to keep the patch simpler until we have a use for it?
```diff
  private val serializer = dep.serializer.newInstance()
  private val transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true)
- private val fileBufferSize = sparkConf.get(config.SHUFFLE_FILE_BUFFER_SIZE).toInt * 1024
+ private val fileBufferSize = 64 << 10
```
Why? Anytime I see a magic number I want to understand how we got that number. Also, why are we replacing a config with this? If 64k is much better than 32k, okay, but should we have a separate RAPIDS config for this instead? Is 64k always the ideal?
```scala
private def createBatchState(
    batchId: Int,
    writer: ShuffleMapOutputWriter): BatchState = {
  import java.util.concurrent.atomic.AtomicInteger
```
Didn't we already import this on line 22? And the ones below, except for ThreadFactory, on line 21? So why not add ThreadFactory to that too?
```scala
mapOutputWriter: ShuffleMapOutputWriter,
partitionBuffers: ConcurrentHashMap[Int, OpenByteArrayOutputStream],
partitionFutures: ConcurrentHashMap[Int,
  java.util.concurrent.CopyOnWriteArrayList[Future[(Long, Long)]]],
```
nit: Why use the full name here instead of importing it?
```scala
  java.util.concurrent.CopyOnWriteArrayList[Future[(Long, Long)]]],
partitionBytesProgress: ConcurrentHashMap[Int, Long],
partitionFuturesProgress: ConcurrentHashMap[Int, Int],
maxPartitionSeen: java.util.concurrent.atomic.AtomicInteger,
```
I thought we already imported this. Why use the full name here?
```scala
if (reducePartitionId < previousMaxPartition) {
  if (!isMultiBatch) {
    isMultiBatch = true
    logInfo(s"Detected multi-batch scenario for shuffle $shuffleId, " +
```
should this be logDebug?
```scala
  }
} catch {
  case e: Exception => {
    logError(s"Exception in compression task for shuffle $shuffleId", e)
```
Log-and-rethrow is generally an anti-pattern. What do we need from this? Should we be wrapping the exception and throwing that, so that it has the shuffleId in it?
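
For example, a sketch of the wrapping alternative (names taken from the surrounding snippet):

```scala
import java.io.IOException

// Wrap instead of log-and-rethrow, so the shuffle id travels with the failure
// to whoever ultimately handles or logs it.
def wrapCompressionFailure(shuffleId: Int, e: Exception): Nothing =
  throw new IOException(s"Compression task failed for shuffle $shuffleId", e)
```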
```scala
try {
  val waitStart = System.nanoTime()
  batch.writerFuture.get()
  totalSerializationWaitTimeNs += System.nanoTime() - waitStart
```
Do we want to include the time on an error?
```scala
try {
  in.close()
} catch {
  case _: Exception => /* ignore */
```
Instead of just ignoring them, can we at least log these? I don't expect the exception case to be common, but if it does show up it might be nice to have the information so we can debug it.
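
A small sketch of that change (assuming the enclosing class mixes in Spark's `Logging` trait):

```scala
import java.io.Closeable
import org.apache.spark.internal.Logging

// Still swallow close() failures for control flow, but log them so they can
// be debugged if they ever show up.
object StreamCleanup extends Logging {
  def closeQuietly(c: Closeable, what: String): Unit = {
    try {
      c.close()
    } catch {
      case e: Exception => logWarning(s"Failed to close $what, continuing", e)
    }
  }
}
```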
```scala
 * commitAllPartitions() as it would rename/move outputTempFile. Instead, we just
 * close the streams and extract the file reference for later merge.
 *
 * NOTE: Uses reflection to access LocalDiskShuffleMapOutputWriter internals.
```
What if it is not that? Have we verified that it will be a LocalDiskShuffleMapOutputWriter before we ever go down this optimization path? I see the exception below that tells us maybe we should write our own. I don't really feel comfortable with putting this in production without some precautions to avoid crashing a user's query.
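
One possible precaution, sketched: gate the optimization on the concrete writer type and fall back to the existing merge path otherwise.

```scala
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleMapOutputWriter

// Only take the reflection-based merge when the writer is the class whose
// internals we know how to read; otherwise use the slower but safe path.
def canUseReflectionMerge(writer: ShuffleMapOutputWriter): Boolean =
  writer.isInstanceOf[LocalDiskShuffleMapOutputWriter]
```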



When I used async-profiler to profile wall time on one of the spark2a workers, I found that a big portion of time is spent on shuffle write, even though we're already using the multi-threaded shuffle manager in our NDS benchmark.

The reason is that for each task, after each partition's data is written to disk, we move all the small files into one large file sequentially. There are some potential ways of optimizing this:
We can expect the input iterator to produce a sequence of partitions like {0,0,1,2,3,3,4...}. Within the constraint of the memory limiter (BytesInFlightLimiter), we can first serialize & compress partitions {0,0,1,2} into memory buffers and write those buffers directly to the destination file. After the buffers are written, we call BytesInFlightLimiter#release to allow {3,3,4} to start working.
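
A simplified sketch of this strategy (not the PR's actual code; `acquire`, `release`, and `writeToFinalFile` stand in for the BytesInFlightLimiter and the output writer):

```scala
import java.io.ByteArrayOutputStream

// Records arrive sorted by partition id, so each partition is staged in a
// memory buffer and streamed straight to the destination file, releasing the
// limiter as soon as its bytes are on disk.
def writeSortedPartitions(
    records: Iterator[(Int, Array[Byte])], // (partitionId, payload), ids ascending
    acquire: Long => Unit,                 // BytesInFlightLimiter#acquireOrBlock
    release: Long => Unit,                 // BytesInFlightLimiter#release
    writeToFinalFile: (Int, Array[Byte]) => Unit): Unit = {
  var currentId = -1
  var acquired = 0L
  var buffer = new ByteArrayOutputStream()
  def flush(): Unit = if (currentId >= 0) {
    writeToFinalFile(currentId, buffer.toByteArray) // no per-partition temp file
    release(acquired) // e.g. let {3,3,4} start once {0,0,1,2} hit disk
    acquired = 0L
    buffer = new ByteArrayOutputStream()
  }
  records.foreach { case (pid, payload) =>
    if (pid != currentId) { flush(); currentId = pid }
    acquire(payload.length.toLong)
    acquired += payload.length
    buffer.write(payload) // the real code serializes + compresses here
  }
  flush()
}
```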
This optimization brings about an 8% decrease in NDS e2e time on spark2a (full report):

After the optimization we can still see the task thread stalled by the writing thread, but given the flame graph of the writing thread and the monitoring stats, I believe it's because we're intermittently hitting the limit of disk write speed.
Other things I tried:

- Parallelize the step of writing memory buffers to the final destination file (this requires customizing LocalDiskShuffleMapOutputWriter to make it thread-safe and accept offsets from different partitions). I tried this, but no improvement was observed, possibly because we already hit the limit of disk write speed.
- Increase spark.rapids.shuffle.multiThreaded.maxBytesInFlight: tried, no improvement.
- Set spark.io.compression.codec=zstd, as zstd is considered to have a better compression ratio than the default lz4, so that we can reduce the footprint on disk. I tried, but no improvement.