
Conversation

@binmahone (Collaborator) commented Sep 24, 2025

[flame graph screenshot: wall time on a Spark2a worker]

When I used async-profiler to profile wall time on one of the Spark2a workers, I found that a large portion of time is spent on shuffle write, even though we're already using the multi-threaded shuffle manager in our NDS benchmark:

"--conf" "spark.shuffle.manager=com.nvidia.spark.rapids.spark320.RapidsShuffleManager"
"--conf" "spark.rapids.shuffle.multiThreaded.writer.threads=32"
"--conf" "spark.rapids.shuffle.multiThreaded.reader.threads=32"
"--conf" "spark.rapids.shuffle.mode=MULTITHREADED"

The reason is that, for each task, after each partition's data is written to disk we move all the small files into one large file sequentially. There are some potential ways of optimizing this:

  1. The workload appears to be bound by disk writes, so it would be nice if we could reduce disk write overhead. If a task produces only one GPU batch for the shuffle writer (by the way, this is a very common case), we can fully exploit the fact that the rows in this single batch are already sorted by partition (see the snapshot below). With some code changes we can skip the step of writing each partition into small files:
[screenshot: partition IDs sorted within the single GPU batch]

We can expect the input iterator to produce a sequence of partition IDs like {0,0,1,2,3,3,4...}. Within the constraint of the memory limiter (BytesInFlightLimiter), we first serialize and compress partitions {0,0,1,2} into memory buffers and write those buffers directly to the destination file. After these buffers are written, we call BytesInFlightLimiter#release to allow {3,3,4} to start working.
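As a rough illustration of the idea, here is a minimal, self-contained sketch (not the plugin's actual classes; `BytesInFlightLimiterSketch` and `writeSingleBatch` are made up for this example): acquire the limiter, serialize/compress a partition into a memory buffer, append the buffer to the single destination file, then release the limiter so later partitions can proceed.

```scala
import java.io.{ByteArrayOutputStream, FileOutputStream}

// Simplified stand-in for the plugin's BytesInFlightLimiter.
class BytesInFlightLimiterSketch(maxBytes: Long) {
  private var inFlight = 0L
  def acquireOrBlock(n: Long): Unit = synchronized {
    while (inFlight + n > maxBytes) wait()
    inFlight += n
  }
  def release(n: Long): Unit = synchronized {
    inFlight -= n
    notifyAll()
  }
}

// Partitions arrive in sorted order, e.g. {0,0,1,2,3,3,4...}, so each one can be
// buffered in memory and appended straight to the destination file -- no
// per-partition temp files and no merge step afterwards.
def writeSingleBatch(
    partitions: Iterator[(Int, Array[Byte])], // (partitionId, serialized+compressed bytes)
    limiter: BytesInFlightLimiterSketch,
    out: FileOutputStream): Unit = {
  partitions.foreach { case (_, bytes) =>
    limiter.acquireOrBlock(bytes.length.toLong)
    val buf = new ByteArrayOutputStream()
    buf.write(bytes)
    buf.writeTo(out) // append directly to the final shuffle file
    limiter.release(bytes.length.toLong)
  }
}
```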

This optimization brings about an 8% decrease in NDS e2e time on spark2a (full report):

Name = benchmark
Means = 357200.0, 330600.0
Time diff = 26600.0
Speedup = 1.0804597701149425
T-Test (test statistic, p value, df) = 9.778354473847644, 1.0031842174347506e-05, 8.0
T-Test Confidence Interval = 20326.99055239971, 32873.009447600285
ALERT: significant change has been detected (p-value < 0.05)
ALERT: improvement in performance has been observed

After the optimization we can still see the task thread stalled by the writing thread, but given the flame graph of the writing thread and the monitoring stats, I believe this is because we're intermittently hitting the disk write speed limit.

[flame graph of the writing thread and disk-write monitoring screenshots]

  2. We could also parallelize the step of writing memory buffers to the final destination file (this requires customizing LocalDiskShuffleMapOutputWriter to make it thread safe and accept offsets from different partitions). I tried this, but no improvement was observed, possibly because we already hit the disk write speed limit.

  3. Increase spark.rapids.shuffle.multiThreaded.maxBytesInFlight. Tried; no improvement.

  4. Set spark.io.compression.codec=zstd, since zstd is considered to have a better compression ratio than the default lz4, so we can reduce the footprint on disk. I tried it, but saw no improvement.

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
Copilot AI review requested due to automatic review settings September 24, 2025 09:16
binmahone marked this pull request as draft September 24, 2025 09:16
binmahone self-assigned this Sep 24, 2025


sameerz added the performance (A performance related task/issue) label Sep 24, 2025
binmahone changed the title from "[DO NOT REVIEW] optimize nds" to "[DO NOT REVIEW] optimize NDS benchmark e2e time" Sep 25, 2025
Copilot AI (Contributor) left a comment

Pull Request Overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.



Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
@binmahone (Collaborator Author)

build

binmahone changed the title from "[DO NOT REVIEW] optimize NDS benchmark e2e time" to "reduce NDS benchmark e2e time by optimizing RapidsShuffleManager" Sep 25, 2025
binmahone marked this pull request as ready for review September 25, 2025 07:48
"RAPIDS shuffle time stalled by input stream operations"

// ThreadLocal variable to track if current task has only one ColumnarBatch
private val singleBatchTaskLocal: ThreadLocal[Boolean] = new ThreadLocal[Boolean] {
Collaborator

Let's not keep thread locals for this state.

Can we at least use GpuColumnVector.tagAsFinalBatch? I know tagAsFinalBatch is right now only applied in the sort and the coalesce. If the shuffle writer sees the first batch, and it's also final, then we know it's a single-batch task.
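A sketch of the suggested check (hedged: `isTaggedAsFinalBatch` is assumed here as the read-side counterpart to `tagAsFinalBatch`; the exact accessor name in the plugin may differ):

```scala
import com.nvidia.spark.rapids.GpuColumnVector
import org.apache.spark.sql.vectorized.ColumnarBatch

// If the very first batch the shuffle writer sees is also tagged as final,
// the task produced exactly one batch and the single-batch path can be taken.
def isSingleBatchTask(batch: ColumnarBatch, isFirstBatchSeen: Boolean): Boolean =
  isFirstBatchSeen && GpuColumnVector.isTaggedAsFinalBatch(batch)
```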

Collaborator Author

OK, is there any particular reason why a thread local is not favored?

Collaborator

Because now the shuffle writer has state in GpuShuffleExchangeExecBase that it needs to query. My point is not about the thread local per se, but more about scribbling away state in different places for something that could be communicated by the batch itself. We are crossing interface boundaries, arguably unnecessarily.

Collaborator Author

I added a ColumnVectorWithState, because tagAsFinalBatch cannot be used directly in our case.

if (isFirstCall) {
isFirstCall = false

if (!iter.hasNext) {
@abellina (Collaborator) Sep 25, 2025

we have to ensure that hasNext will not cause the whole plan to execute, in every circumstance.

Collaborator Author

I think the biggest problem may lie in the Scan part; typically we do the real reading in hasNext(). But I do see some similar patterns where we call hasNext to check whether it's the single-batch case:

[screenshot: an existing hasNext check used to detect the single-batch case]

Does this look like a problem to you? @abellina

Collaborator

I think the biggest problem may lie in the Scan part; typically we do the real reading in hasNext().

Yes, this is the type of issue I was expecting to find. Though in the example you provided, the function returns an iterator (CachedGpuBatchIterator) which seems inert in hasNext (it just checks pending.nonEmpty, which was needed to instantiate the iterator object).

What I meant by checking every case is that it has to be exhaustive, and that can be pretty hard to get right. Sometimes we extend Iterator, sometimes we just do it inline: new Iterator[T] { ... }. I do not know if there are clever ways of finding all the issues here.
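To make the concern concrete, here is an illustrative (made-up) contrast between an iterator whose hasNext is inert and one whose hasNext triggers the real work, which is the pattern an exhaustive audit would have to find:

```scala
import scala.collection.mutable

// hasNext only inspects already-materialized state (cheap, like the
// CachedGpuBatchIterator case above).
class InertIterator[T](pending: mutable.Queue[T]) extends Iterator[T] {
  override def hasNext: Boolean = pending.nonEmpty
  override def next(): T = pending.dequeue()
}

// hasNext performs the real read (e.g. a scan), so "peeking" to detect a
// single-batch task would itself kick off expensive work.
class EagerIterator[T](readNext: () => Option[T]) extends Iterator[T] {
  private var buffered: Option[T] = None
  override def hasNext: Boolean = {
    if (buffered.isEmpty) buffered = readNext()
    buffered.isDefined
  }
  override def next(): T = {
    if (!hasNext) throw new NoSuchElementException
    val v = buffered.get
    buffered = None
    v
  }
}
```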

@abellina (Collaborator)

An 8% win would be amazing. I had a couple of comments around both .hasNext and the limiter-release piece. I am curious how much of the 8% is because we queued extra work given the .hasNext, and how much is due to the limiter being released too early (before we close the buffer).


// Track compression futures per partition
val partitionFutures =
new ConcurrentHashMap[Int, java.util.concurrent.ConcurrentLinkedQueue[Future[Long]]]()
Collaborator

for a single-batch task shouldn't this just be a map of reducer id->Future[Long]?

How many futures do you expect?

Collaborator Author

This is for the case where a sub-partition is further split by spark.rapids.shuffle.partitioning.maxCpuBatchSize.

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
@binmahone (Collaborator Author)

Our spark2a cluster is not available today. I have to wait until tomorrow to collect the latest perf numbers on NDS.

@binmahone (Collaborator Author)

build

@binmahone (Collaborator Author)

build

@binmahone (Collaborator Author)

build

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
@binmahone (Collaborator Author)

binmahone commented Oct 16, 2025

Hi @abellina and @revans2, I fixed the CI errors and the PR is now ready for review. My final regression test (https://gist.github.com/binmahone/3af72e5a9377e0e292f03ca477bd761c) shows a 7.3% overall improvement, with some queries improving by as much as 20~30%. There is one regressed query (query93), but after checking the event log it seems to be related to Scan jitter.

Could you please go through the PR again, and if possible also try running the regression test (with your preferred configs) to see if the improvement is reproducible?

The flame graph after this PR is applied:

[flame graph screenshot after this PR]

As we can see, the shuffle write part still accounts for a portion of the time (bottom right corner), but not nearly as large a portion anymore.

@binmahone (Collaborator Author)

I tried another two rounds (in each round, the test version and the baseline version were each run 3~4 times) and got 9.6% and 9.2% improvements respectively.

So overall I would say this PR brings a 7.3%~9.6% improvement to NDS.

@revans2 (Collaborator) left a comment

In general this looks good to me. My main concern is what happens to the performance in the non-happy case.

To figure this out I ran a simple query with 4 different shuffle implementations at 4 different maxPartitionBytes settings.

// NOTE --conf spark.shuffle.sort.bypassMergeThreshold=199 for sort-shuffle
spark.conf.set("spark.sql.shuffle.partitions", 200)
//spark.conf.set("spark.sql.files.maxPartitionBytes", "512m")
//spark.conf.set("spark.sql.files.maxPartitionBytes", "1024m")
//spark.conf.set("spark.sql.files.maxPartitionBytes", "2048m")
spark.conf.set("spark.sql.files.maxPartitionBytes", "4096m")
(0 until 5).foreach { _ =>
  spark.time(spark.read.parquet("/data/tpcds/SF1000_tmp_parquet/store_sales/").selectExpr("ss_sold_date_sk", "ss_sold_time_sk", "ss_item_sk", "ss_customer_sk", "ss_cdemo_sk", "ss_hdemo_sk", "ss_addr_sk").repartition(col("ss_sold_date_sk"),col("ss_sold_time_sk")).selectExpr("COUNT(ss_sold_date_sk)","COUNT(ss_sold_time_sk)", "COUNT(ss_item_sk)", "COUNT(ss_customer_sk)", "COUNT(ss_cdemo_sk)", "COUNT(ss_hdemo_sk)", "COUNT(ss_addr_sk)").show())
}

At 512m this patch was about 5% faster than the regular multi-threaded shuffle (but with low confidence); this is one shuffle batch per task.

At 1024m this patch was only 3% faster than the regular multi-threaded shuffle (again with low confidence); this is still one shuffle batch per task.

At 2048m this patch was 3% slower than the regular multi-threaded shuffle (yet again with low confidence); this one was about 2.8 batches per task.

At 4096m this patch was 14% slower than the MT shuffle, 15% slower than the sort bypass shuffle, and 13% slower than the sort-based shuffle (all with high confidence); this one was about 4.8 batches per task.

It looks like this provides a very good speedup for the happy path, but when you check for the next batch and it needs to materialize the whole thing, the memory pressure gets high enough that we spill more and end up running with less parallelism.

I am not really sure how to mitigate this, as I am not 100% sure yet what is happening. We might be able to copy the data to the host before we check to see if there is a second batch; this should reduce the memory pressure. We might be able to always write the first batch out to disk, and then take the hit of copying it if we guessed wrong, but to make that work we need to ensure that we don't need to decompress the data before writing it out again.

_ => new OpenByteArrayOutputStream())
val originLength = buffer.getCount

// Serialize + compress to memory buffer
Collaborator

nit: this could also be doing encryption. Very minor.

@binmahone (Collaborator Author)

Hi @revans2, can you please share your steps? I tried to reproduce (detailed steps at https://gist.github.com/binmahone/ba8597843af0bf50cf7e5ec6013c947b) but here's what I got:

[screenshot: Spark history page comparing baseline-jar and PR-jar runs]

As you can see, I'm comparing two jars: the baseline jar and the PR jar. Each app you see in the Spark history page runs the same query with the same config 5 times.

For the max partition bytes = 4g case, the PR jar performs similarly to the baseline jar.

For the max partition bytes = 512m case (i.e. the single-batch case), the PR jar performs better than the baseline jar when the RapidsShuffleManager is enabled, but performs similarly to the baseline jar when the RapidsShuffleManager is disabled. This is as expected, because the current PR is essentially an enhancement to RapidsShuffleManager's write routine.

It's also worth mentioning that I collected the above results on @thirtiseven's workstation. At first I was collecting on my own workstation, but even with the baseline jar the duration varied significantly across the 5 runs of the same query with the same config. One possible reason is that my disk is already >80% utilized, and an SSD that is nearly full becomes unstable performance-wise. Anyway, tomorrow I'll plug another disk into my workstation and try again.

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR optimizes shuffle write performance for the common single-batch scenario by eliminating intermediate temporary files and writing partitions directly to memory buffers, then streaming them to disk. The optimization achieves an 8% reduction in NDS end-to-end time.

Key changes:

  • Refactored RapidsShuffleThreadedWriterBase to use a new streaming architecture where serialization/compression happens to in-memory buffers (OpenByteArrayOutputStream) followed by sequential writes to the final output file
  • Single-batch optimization: When a task produces only one GPU batch (common case), partitions are written directly without temporary files
  • Multi-batch support: When multiple batches are detected (partition IDs decrease), creates separate partial files and merges them at the end
  • Introduced ColumnVectorWithState base class to track batch finality across all column vector types
  • Added pipeline processing with dedicated writer threads per batch for parallelism

Issues found:

  • Critical logic bug in writer thread termination condition (line 378-379) that could cause the writer to attempt processing a non-existent partition

Confidence Score: 2/5

  • This PR contains a critical loop termination bug that could cause runtime failures
  • The optimization logic is sound and well-documented, but there's a critical bug in the writer thread's loop termination condition (lines 378-379). The condition uses != when it should use <, which could cause the writer to attempt processing partition numPartitions (which doesn't exist, since partitions are 0-indexed from 0 to numPartitions-1). This would likely cause an IndexOutOfBoundsException or similar error. Additionally, the use of reflection to extract internal fields from ShuffleMapOutputWriter is fragile and could break with Spark version changes.
  • sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala requires immediate attention to fix the loop termination bug

Important Files Changed

File Analysis

Filename Score Overview
sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala 2/5 Major refactoring of shuffle writer with single-batch optimization and multi-batch pipeline support; contains critical loop termination bug
sql-plugin/src/main/java/com/nvidia/spark/rapids/ColumnVectorWithState.java 5/5 New base class extracting batch state tracking from GpuColumnVectorBase; straightforward refactoring
sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsShuffleWriter.scala 5/5 Added mapOutputWriters tracking for proper cleanup on error; minimal change

Sequence Diagram

sequenceDiagram
    participant MT as Main Thread
    participant Rec as Records Iterator
    participant Lim as BytesInFlightLimiter
    participant WT as Writer Thread Pool
    participant WTh as Batch Writer Thread
    participant Buf as Memory Buffers
    participant Disk as Output File

    Note over MT,Disk: Single Batch Scenario (Optimized Path)
    
    MT->>Rec: hasNext()
    Rec-->>MT: true
    MT->>Rec: next() → (key, value, partitionId)
    
    MT->>MT: incRefCountAndGetSize(value)
    MT->>Lim: acquireOrBlock(recordSize)
    Note over Lim: Blocks if maxBytesInFlight exceeded
    Lim-->>MT: acquired
    
    MT->>WT: queueWriteTask(slotNum, serializeTask)
    WT->>WT: serialize + compress to memory buffer
    WT->>Buf: write compressed data
    WT-->>MT: Future[(recordSize, compressedSize)]
    
    MT->>WTh: notify writer thread (writerCondition)
    
    par Main Thread continues
        MT->>Rec: next() → next partition data
    and Writer Thread processes
        WTh->>WTh: wait for futures to complete
        WTh->>WT: future.get() → sizes
        WTh->>Disk: write buffer[partition] to disk
        WTh->>Lim: release(recordSize)
    end
    
    Note over MT: All records processed
    MT->>MT: maxPartitionSeen.set(numPartitions)
    MT->>MT: processingComplete.set(true)
    MT->>WTh: notify (writerCondition)
    
    WTh->>WTh: finish remaining partitions
    WTh->>Disk: write final buffers
    WTh-->>MT: complete
    
    MT->>Disk: commitAllPartitions()
    
    Note over MT,Disk: Multi-Batch Scenario
    
    MT->>Rec: detect partitionId < previousMax
    MT->>MT: isMultiBatch = true
    MT->>WTh: signal batch1 complete
    MT->>MT: create new BatchState (batch2)
    MT->>MT: spawn new Writer Thread
    
    par Batch1 finalization
        WTh->>Disk: write to partial file 1
    and Batch2 processing
        MT->>Rec: continue with batch2 records
    end
    
    Note over MT: All batches complete
    MT->>MT: wait for all writer threads
    MT->>Disk: mergePartialFiles()
    MT->>Disk: commitAllPartitions(merged)

10 files reviewed, 1 comment


Comment on lines 378 to 379
while (!processingComplete.get() ||
currentPartitionToWrite != maxPartitionSeen.get()) {
Contributor

logic: loop termination logic has a potential issue

when processingComplete is true and maxPartitionSeen.get() equals numPartitions, the condition currentPartitionToWrite != maxPartitionSeen.get() will be currentPartitionToWrite != numPartitions, which means the loop will terminate when currentPartitionToWrite == numPartitions

however, partition IDs are 0-indexed, so valid partitions are 0 to numPartitions-1. the writer should process up to and including partition numPartitions-1, then exit when currentPartitionToWrite == numPartitions

but line 532 and 614 set maxPartitionSeen.set(numPartitions) which means the writer will try to write partition numPartitions (which doesn't exist) before terminating

Suggested change
while (!processingComplete.get() ||
currentPartitionToWrite != maxPartitionSeen.get()) {
while (!processingComplete.get() ||
currentPartitionToWrite < maxPartitionSeen.get()) {

@abellina (Collaborator)

@binmahone let's keep the >200 case separate from this PR. They are different concerns.

f4dc778

@binmahone (Collaborator Author)

@binmahone let's keep the >200 case separate from this PR. They are different concerns.

f4dc778

It's just that I changed the code so that it addresses the "hasNext overhead issue" and the ">200 partitions issue" at the same time, and they have become difficult to separate.

My hope is that we can merge this PR with minimal scope, but with the original commits I cannot address the "hasNext overhead issue" very well.

There are many other subtasks in #13673; for those I'll open new PRs.

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

Optimizes shuffle write performance by eliminating intermediate temp file I/O for sorted partition data. Instead of writing each partition to separate files then merging, the code now serializes/compresses directly to memory buffers and streams to final output, achieving ~8% NDS benchmark improvement.

Key changes:

  • Removed sequential temp file merge bottleneck by introducing memory buffering with OpenByteArrayOutputStream
  • Added pipeline architecture for multi-batch scenarios where each batch gets independent writer thread
  • Implemented partition-ordered streaming where writer thread consumes compressed buffers as compression tasks complete
  • Uses reflection to extract outputTempFile and partitionLengths from LocalDiskShuffleMapOutputWriter internals for multi-batch merging
  • Tracks multiple ShuffleMapOutputWriter instances in mapOutputWriters collection for proper cleanup

The optimization exploits the fact that GPU batches produce pre-sorted partition data (partition IDs appear in order like {0,0,1,2,3,3,4...}), allowing direct memory buffering within BytesInFlightLimiter constraints.

Confidence Score: 4/5

  • This PR is reasonably safe to merge with careful testing, though the complexity and use of reflection warrant thorough validation
  • The implementation is a significant architectural change that improves performance (~8% in NDS benchmarks) but introduces substantial complexity. The code uses reflection to access Spark internals, has intricate multi-threaded coordination logic, and handles multiple execution paths (single vs multi-batch). The previous comment identified a potential loop termination issue that needs verification. The change has been performance-tested but the correctness of edge cases (empty partitions, error handling, resource cleanup) needs careful validation.
  • Pay special attention to the writer thread loop termination logic (lines 378-437), multi-batch detection logic (lines 527-554), and reflection-based file extraction (lines 826-859)

Important Files Changed

File Analysis

Filename Score Overview
sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala 4/5 Major refactor of shuffle writer to optimize disk I/O by writing sorted partition data directly to memory buffers, then streaming to disk. Adds multi-batch pipeline processing with independent writer threads per batch. Uses reflection to access internal Spark fields.

Sequence Diagram

sequenceDiagram
    participant Main as Main Thread
    participant Limiter as BytesInFlightLimiter
    participant CompPool as Compression Thread Pool
    participant Writer as Writer Thread
    participant Disk as Disk I/O

    Note over Main: Single Batch Case
    Main->>Main: detect partition sorted data
    Main->>Writer: create BatchState & writer thread
    
    loop For each record
        Main->>Main: getPartition(key)
        Main->>Limiter: acquireOrBlock(recordSize)
        Limiter-->>Main: acquired
        Main->>CompPool: queue compression task (slotNum)
        Main->>Writer: notify new data available
        
        CompPool->>CompPool: serialize + compress to memory buffer
        CompPool->>CompPool: return (recordSize, compressedSize)
        
        Writer->>Writer: wait for partition complete
        Writer->>Writer: get futures for partition
        Writer->>Writer: future.get() blocks until compression done
        Writer->>Disk: write buffer to partition file
        Writer->>Limiter: release(recordSize)
    end
    
    Main->>Writer: set processingComplete = true
    Main->>Writer: set maxPartitionSeen = numPartitions
    Main->>Writer: notifyAll()
    Writer->>Disk: flush remaining partitions
    Writer-->>Main: complete
    
    Note over Main: Multi-Batch Case
    Main->>Main: detect partitionId < previousMax
    Main->>Main: signal current batch complete
    Main->>Writer: set maxPartitionSeen = numPartitions
    Main->>Main: create NEW BatchState (pipeline!)
    Main->>Main: continue processing next batch
    
    Note over Main,Disk: Batches process in parallel
    Main->>Main: wait for all batch writers
    Main->>Main: merge partial files (partition by partition)
    Main->>Disk: write final merged output

1 file reviewed, 1 comment


Comment on lines 517 to 627
var currentBatch = createBatchState(currentBatchId, mapOutputWriter)

try {
while (records.hasNext) {
val record = records.next()
val key = record._1
val value = record._2
val reducePartitionId: Int = partitioner.getPartition(key)

// Detect multi-batch: partition ID decreased means new batch started
if (reducePartitionId < previousMaxPartition) {
if (!isMultiBatch) {
isMultiBatch = true
logInfo(s"Detected multi-batch scenario for shuffle $shuffleId, " +
s"transitioning to pipeline mode")
}
} finally {
// This is in a finally block so that if there is an exception queueing
// futures, that we will have waited for any queued write future before we call
// .abort on the map output writer (we had test failures otherwise)
NvtxRegistry.WAITING_FOR_WRITES {
try {
while (writeFutures.nonEmpty) {
try {
writeFutures.dequeue().get()
} catch {
case ee: ExecutionException =>
// this exception is a wrapper for the underlying exception
// i.e. `IOException`. The ShuffleWriter.write interface says
// it can throw these.
throw ee.getCause
}

// Signal current batch is complete (but don't block next batch!)
currentBatch.maxPartitionSeen.set(numPartitions)
currentBatch.processingComplete.set(true)
currentBatch.writerCondition.synchronized {
currentBatch.writerCondition.notifyAll()
}

// Add to list for later finalization
batchStates += currentBatch

// Immediately create new batch and continue processing (pipeline!)
currentBatchId += 1
val newWriter = shuffleExecutorComponents.createMapOutputWriter(
shuffleId,
mapId,
numPartitions)
mapOutputWriters += newWriter // Track for cleanup
currentBatch = createBatchState(currentBatchId, newWriter)

previousMaxPartition = -1
}

recordsWritten += 1
previousMaxPartition = math.max(previousMaxPartition, reducePartitionId)

// Get or create futures queue for this partition in current batch
val futures = currentBatch.partitionFutures.computeIfAbsent(reducePartitionId,
_ => new java.util.concurrent.CopyOnWriteArrayList[Future[(Long, Long)]]())

val (cb, recordSize) = incRefCountAndGetSize(value)

// Acquire limiter and process compression task immediately
val waitOnLimiterStart = System.nanoTime()
limiter.acquireOrBlock(recordSize)
waitTimeOnLimiterNs += System.nanoTime() - waitOnLimiterStart

// Get or assign a slot number for this partition to ensure
// all tasks for the same partition run serially in the same slot
val slotNum = partitionSlots.computeIfAbsent(reducePartitionId,
_ => RapidsShuffleInternalManagerBase.getNextWriterSlot)
val finalCurrentBatch = currentBatch
val future = RapidsShuffleInternalManagerBase.queueWriteTask(slotNum, () => {
try {
withResource(cb) { _ =>
// Get or create buffer for this partition in current batch
val buffer = finalCurrentBatch.partitionBuffers.computeIfAbsent(
reducePartitionId, _ => new OpenByteArrayOutputStream())
val originLength = buffer.getCount

// Serialize + compress to memory buffer
val compressedOutputStream = blockManager.serializerManager.wrapStream(
ShuffleBlockId(shuffleId, mapId, reducePartitionId), buffer)

val serializationStream = serializerInstance.serializeStream(
compressedOutputStream)
withResource(serializationStream) { serializer =>
serializer.writeKey(key.asInstanceOf[Any])
serializer.writeValue(value.asInstanceOf[Any])
}
} finally {
// cancel all pending futures (only in case of error will we cancel)
writeFutures.foreach(_.cancel(true /*ok to interrupt*/))

// Track total written data size (compressed size)
val compressedSize = (buffer.getCount - originLength).toLong
totalCompressedSize.addAndGet(compressedSize)
(recordSize, compressedSize)
}
} catch {
case e: Exception => {
logError(s"Exception in compression task for shuffle $shuffleId", e)
throw e
}
}
})

currentBatch.maxPartitionSeen.synchronized {
futures.add(future)
currentBatch.maxPartitionSeen.set(
math.max(currentBatch.maxPartitionSeen.get(), reducePartitionId))
}

// writeTimeNs is an approximation of the amount of time we spent in
// DiskBlockObjectWriter.write, which involves serializing records and writing them
// on disk. As we use multiple threads for writing, writeTimeNs is
// estimated by 'the total amount of time it took to finish processing the entire logic
// above' minus 'the amount of time it took to do anything expensive other than the
// serialization and the write. The latter involves computations in upstream execs,
// ColumnarBatch size estimation, and the time blocked on the limiter.
val writeTimeNs = (System.nanoTime() - processingStart) -
records.getIterateTimeNs - batchSizeComputeTimeNs - waitTimeOnLimiterNs

val combineTimeStart = System.nanoTime()
val pl = writePartitionedData(mapOutputWriter)
val combineTimeNs = System.nanoTime() - combineTimeStart

// add openTime which is also done by Spark, and we are counting
// in the ioTime later
writeMetrics.incWriteTime(openTimeNs)

// At this point, Spark has timed the amount of time it took to write
// to disk (the IO, per write). But note that when we look at the
// multi threaded case, this metric is now no longer task-time.
// Users need to look at "rs. shuffle write time" (shuffleWriteTimeMetric),
// which does its own calculation at the task-thread level.
// We use ioTimeNs, however, to get an approximation of serialization time.
val ioTimeNs =
writeMetrics.asInstanceOf[ThreadSafeShuffleWriteMetricsReporter].getWriteTime

// serializationTime is the time spent compressing/encoding batches that wasn't
// counted in the ioTime
val totalPerRecordWriteTime = recordWriteTime.get() + ioTimeNs
val ioRatio = (ioTimeNs.toDouble/totalPerRecordWriteTime)
val serializationRatio = 1.0 - ioRatio

// update metrics, note that we expect them to be relative to the task
ioTimeMetric.foreach(_ += (ioRatio * writeTimeNs).toLong)
serializationTimeMetric.foreach(_ += (serializationRatio * writeTimeNs).toLong)
// we add all three here because this metric is meant to show the time
// we are blocked on writes
shuffleWriteTimeMetric.foreach(_ += (writeTimeNs + combineTimeNs))
shuffleCombineTimeMetric.foreach(_ += combineTimeNs)
pl
// Wake up writer thread for current batch
currentBatch.writerCondition.synchronized {
currentBatch.writerCondition.notifyAll()
}
}
myMapStatus = Some(getMapStatus(blockManager.shuffleServerId, partLengths, mapId))
} catch {
// taken directly from BypassMergeSortShuffleWriter
case e: Exception =>

// Mark end of last batch - ensure all partitions are processed
currentBatch.maxPartitionSeen.set(numPartitions)
currentBatch.processingComplete.set(true)
currentBatch.writerCondition.synchronized {
currentBatch.writerCondition.notifyAll()
}

// Add last batch to list
batchStates += currentBatch
Contributor

logic: if exception occurs during the processing loop (lines 520-617), currentBatch won't have been added to batchStates yet (added at line 627), so the finally block won't clean it up

the writer executor, open streams, and buffers in currentBatch could leak

consider adding currentBatch to batchStates immediately after creation at line 517, or ensure cleanup in finally block handles the in-progress batch separately

Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
@binmahone (Collaborator Author)

build

@greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR optimizes the RAPIDS shuffle manager by implementing streaming parallel processing with pipelined partition writing, achieving an 8% reduction in NDS end-to-end time.

Key changes:

  • Single-batch optimization: For tasks producing one GPU batch, rows are already sorted by partition. The new implementation serializes/compresses partitions directly to memory buffers and writes them sequentially to the destination file, eliminating intermediate temporary files per partition.
  • Multi-batch support: When multiple batches are detected (partition ID decreases), the system creates independent batch states that can process in pipeline, writing to separate temporary files that are later merged.
  • Parallel compression: Multiple compression tasks can run concurrently in a thread pool (partitioned by partition ID to maintain ordering), while a dedicated writer thread for each batch writes completed partitions in order.
  • Metrics update: Removed deprecated serialization and write time metrics, added new metrics for limiter wait time and serialization wait time.

The previous comment about loop termination logic at line 378 appears to be incorrect upon analysis. The loop correctly processes partitions 0 through numPartitions-1, using maxPartitionSeen.set(numPartitions) as a sentinel value to indicate completion. The condition currentPartitionToWrite < numPartitions ensures the loop exits after processing partition numPartitions-1.
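For reference, a condensed, simplified sketch of the sentinel pattern being described (shown with the `<` comparison from the earlier suggestion; the real code coordinates through futures and a condition variable instead of spinning):

```scala
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

def writerLoop(
    processingComplete: AtomicBoolean,  // set once the task thread is done queueing work
    maxPartitionSeen: AtomicInteger,    // set to numPartitions as the completion sentinel
    writePartition: Int => Unit): Unit = {
  var currentPartitionToWrite = 0
  while (!processingComplete.get() ||
      currentPartitionToWrite < maxPartitionSeen.get()) {
    if (currentPartitionToWrite < maxPartitionSeen.get()) {
      // a partition is only written once a higher partition id (or the sentinel)
      // has been observed, so its buffer can no longer grow
      writePartition(currentPartitionToWrite)
      currentPartitionToWrite += 1
    }
    // the real implementation waits on writerCondition here rather than busy-looping
  }
}
```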

However, the previous comment about resource cleanup at line 517 identifies a valid issue: if an exception occurs during record processing before currentBatch is added to batchStates (line 541 or 625), the batch's executor, futures, and buffers won't be cleaned up in the finally block.

Confidence Score: 3/5

  • This PR has significant performance benefits but contains a resource leak vulnerability that could cause issues under failure scenarios
  • Score reflects the substantial refactoring with proven performance gains (8% improvement), but the resource cleanup issue poses a real risk in production. If exceptions occur during batch processing, thread executors and memory buffers may not be properly released. The core shuffle logic appears sound, and the multi-batch detection/pipelining is well-designed, but the exception handling needs improvement before merge.
  • Pay close attention to RapidsShuffleInternalManagerBase.scala - the resource cleanup logic in the finally block needs to handle the in-flight currentBatch that may not yet be added to batchStates

Important Files Changed

File Analysis

Filename Score Overview
sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala 3/5 Major refactoring of shuffle writer to use streaming parallel processing with pipelined partition writing. Resource cleanup issue exists if exception occurs during initial batch processing before batch is added to tracking list.
sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala 5/5 Removed deprecated shuffle metrics and added new threaded writer metrics for tracking limiter wait time and serialization wait time.

Sequence Diagram

sequenceDiagram
    participant MT as Main Thread
    participant L as BytesInFlightLimiter
    participant TP as Thread Pool (per partition slot)
    participant WT as Writer Thread
    participant FS as File System

    Note over MT: Single Batch Scenario
    MT->>MT: Create BatchState with writer thread
    WT->>WT: Start writer thread (waits for data)
    
    loop For each record in batch
        MT->>MT: Get partition ID from partitioner
        MT->>L: acquireOrBlock(recordSize)
        L-->>MT: Acquired
        MT->>TP: Queue compression task (by partition slot)
        TP->>TP: Serialize + compress to memory buffer
        TP-->>MT: Return Future[(recordSize, compressedSize)]
        MT->>MT: Add future to partition's futures list
        MT->>WT: Notify writer thread
        
        WT->>WT: Check if current partition ready
        WT->>TP: Wait for futures to complete
        TP-->>WT: Compressed data ready
        WT->>FS: Write compressed data to partition stream
        WT->>L: release(recordSize)
    end
    
    MT->>MT: Signal completion (maxPartitionSeen = numPartitions)
    MT->>WT: Notify writer thread
    WT->>WT: Process remaining partitions
    WT->>FS: Close all partition streams
    WT-->>MT: Writer thread completes
    MT->>FS: Commit all partitions
    
    Note over MT: Multi-Batch Scenario
    MT->>MT: Create BatchState 0
    WT->>WT: Start writer thread 0
    
    loop For each record in batch 0
        MT->>L: acquireOrBlock(recordSize)
        MT->>TP: Queue compression task
        MT->>WT: Notify writer thread 0
        WT->>FS: Write to temp file 0
    end
    
    MT->>MT: Detect partition ID decrease (new batch!)
    MT->>MT: Signal batch 0 complete (maxPartitionSeen = numPartitions)
    MT->>MT: Create BatchState 1 (pipeline!)
    WT->>WT: Start writer thread 1
    
    par Batch 0 writer completes
        WT->>FS: Finish writing temp file 0
    and Batch 1 processes in parallel
        loop For each record in batch 1
            MT->>L: acquireOrBlock(recordSize)
            MT->>TP: Queue compression task
            MT->>WT: Notify writer thread 1
            WT->>FS: Write to temp file 1
        end
    end
    
    MT->>MT: All batches complete
    MT->>MT: Merge partial files
    
    loop For each partition 0..N
        loop For each partial file
            MT->>FS: Copy partition P from partial file to final output
        end
    end
    
    MT->>FS: Commit final merged output
    MT->>MT: Cleanup batch states and temp files

4 files reviewed, no comments


@nvauto (Collaborator)

nvauto commented Nov 17, 2025

NOTE: release/25.12 has been created from main. Please retarget your PR to release/25.12 if it should be included in the release.

@revans2 (Collaborator) left a comment

I think I got through all of it this time.

* This class provides state tracking for batch processing, specifically
* for tracking whether a batch is final or a sub-partition of the final batch.
*/
public abstract class ColumnVectorWithState extends ColumnVector {
Collaborator

This change feels a bit premature. I can see that this extends the final-batch state tracking to a few host column vectors now, but I don't see anywhere that the final-batch tracking is carried over when copying to the host buffers, and I also don't see anywhere that isKnownFinalBatch is used by this patch. Could we remove it just to make the patch simpler until we have a use for it?

private val serializer = dep.serializer.newInstance()
private val transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true)
private val fileBufferSize = sparkConf.get(config.SHUFFLE_FILE_BUFFER_SIZE).toInt * 1024
private val fileBufferSize = 64 << 10
Collaborator

Why? Any time I see a magic number I want to understand how we got it. Also, why are we replacing a config with this? If 64k is much better than 32k, okay, but should we have a separate RAPIDS config for it instead? Is 64k always the ideal?
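A minimal sketch of the kind of alternative being asked about, keeping the buffer size driven by configuration rather than a hard-coded 64 KiB (reusing Spark's own spark.shuffle.file.buffer here just for illustration; a dedicated RAPIDS config would be the other option):

```scala
import org.apache.spark.SparkConf

// Read the shuffle file buffer size (in KiB) from Spark's config and convert to bytes,
// instead of baking in a magic 64 << 10.
def shuffleFileBufferBytes(conf: SparkConf): Int =
  conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
```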

private def createBatchState(
batchId: Int,
writer: ShuffleMapOutputWriter): BatchState = {
import java.util.concurrent.atomic.AtomicInteger
Collaborator

Didn't we already import this on line 22? And the ones below, except for ThreadFactory, on line 21? So why not add ThreadFactory there too?

mapOutputWriter: ShuffleMapOutputWriter,
partitionBuffers: ConcurrentHashMap[Int, OpenByteArrayOutputStream],
partitionFutures: ConcurrentHashMap[Int,
java.util.concurrent.CopyOnWriteArrayList[Future[(Long, Long)]]],
Collaborator

nit: Why use the full name here instead of importing it?

java.util.concurrent.CopyOnWriteArrayList[Future[(Long, Long)]]],
partitionBytesProgress: ConcurrentHashMap[Int, Long],
partitionFuturesProgress: ConcurrentHashMap[Int, Int],
maxPartitionSeen: java.util.concurrent.atomic.AtomicInteger,
Collaborator

I thought we already imported this. Why use the full name here?

if (reducePartitionId < previousMaxPartition) {
if (!isMultiBatch) {
isMultiBatch = true
logInfo(s"Detected multi-batch scenario for shuffle $shuffleId, " +
Collaborator

should this be logDebug?

}
} catch {
case e: Exception => {
logError(s"Exception in compression task for shuffle $shuffleId", e)
Collaborator

Log-and-rethrow is generally an anti-pattern. What do we need from this? Should we wrap the exception and throw that instead, so that it has the shuffleId in it?

try {
val waitStart = System.nanoTime()
batch.writerFuture.get()
totalSerializationWaitTimeNs += System.nanoTime() - waitStart
Collaborator

Do we want to include the time on an error?

try {
in.close()
} catch {
case _: Exception => /* ignore */
Collaborator

Instead of just ignoring these, can we at least log them? I don't expect the exception case to be common, but if it does show up it would be nice to have the information so we can debug it.

* commitAllPartitions() as it would rename/move outputTempFile. Instead, we just
* close the streams and extract the file reference for later merge.
*
* NOTE: Uses reflection to access LocalDiskShuffleMapOutputWriter internals.
Collaborator

What if it is not that? Have we verified that it will be a LocalDiskShuffleMapOutputWriter before we ever go down this optimization path? I see the exception below, which suggests maybe we should write our own. I don't really feel comfortable putting this in production without some precautions to avoid crashing a user's query.
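One possible precaution along the lines being requested (a hedged sketch, not the PR's code; `extractAndMerge` and `writeThroughStandardApi` are placeholders for the optimized and fallback paths):

```scala
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleMapOutputWriter

def finalizeBatch(writer: ShuffleMapOutputWriter): Unit = writer match {
  case local: LocalDiskShuffleMapOutputWriter =>
    // only here is the reflection-based extraction of the temp file / partition
    // lengths known to line up with the implementation
    extractAndMerge(local)
  case other =>
    // unknown writer implementation: fall back rather than risk crashing the query
    writeThroughStandardApi(other)
}

// placeholders for the two paths referenced above
def extractAndMerge(local: LocalDiskShuffleMapOutputWriter): Unit = ???
def writeThroughStandardApi(writer: ShuffleMapOutputWriter): Unit = ???
```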
