diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh index 047c8b73f5e..eb63174d6a9 100755 --- a/integration_tests/run_pyspark_from_build.sh +++ b/integration_tests/run_pyspark_from_build.sh @@ -634,6 +634,11 @@ PY unset PYSP_TEST_spark_jars_repositories unset PYSP_TEST_spark_rapids_memory_gpu_allocSize + + # Comment this out if you want to run remote debug this local mode spark process + # Don't forget to set TEST_PARALLEL=1 to ensure local mode spark + # export SPARK_SUBMIT_OPTS="-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=5005" + exec "$SPARK_HOME"/bin/spark-submit "${jarOpts[@]}" \ --driver-java-options "$driverJavaOpts" \ $SPARK_SUBMIT_FLAGS \ diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala index 83028d6c6be..03841c21677 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2025, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -312,7 +312,7 @@ trait GpuExec extends SparkPlan with Logging { // For GpuShuffleExchangeExecBase and GpuCustomShuffleReaderExec, // we want the op time metrics to be called: - // - "op time (shuffle write partition & serial)" for shuffle write, and + // - "op time (shuffle write partition)" for shuffle write, and // - "op time (shuffle read)" for shuffle read. // That's why we have this separate method to get the metric. def getOpTimeNewMetric: Option[GpuMetric] = allMetrics.get(OP_TIME_NEW) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala index 9b83aa638c4..1ff4aed9b22 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala @@ -155,7 +155,7 @@ object GpuMetric extends Logging { val DESCRIPTION_NUM_PARTITIONS = "partitions" val DESCRIPTION_OP_TIME_LEGACY = "op time (legacy)" val DESCRIPTION_OP_TIME_NEW = "op time" - val DESCRIPTION_OP_TIME_NEW_SHUFFLE_WRITE = "op time (shuffle write partition & serial)" + val DESCRIPTION_OP_TIME_NEW_SHUFFLE_WRITE = "op time (shuffle write partitioning)" val DESCRIPTION_OP_TIME_NEW_SHUFFLE_READ = "op time (shuffle read)" val DESCRIPTION_COLLECT_TIME = "collect batch time" val DESCRIPTION_CONCAT_TIME = "concat batch time" diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala index d5193578583..8086d75c253 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala @@ -42,6 +42,16 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L private val isUnlimited = nonPinnedLimit < 0 private val isPinnedOnly = nonPinnedLimit == 0 + // Expose for usage ratio calculation + def getCurrentAllocated: Long = synchronized { + currentNonPinnedAllocated + currentPinnedAllocated + } + + def getTotalLimit: Long = { + if (isUnlimited) Long.MaxValue + else pinnedLimit + nonPinnedLimit + } + /** * A callback class so we know when a non-pinned host buffer was released */ @@ -297,6 +307,28 @@ object HostAlloc extends Logging { getSingleton.alloc(amount, preferPinned) } + /** + * Get current host memory usage ratio (0.0 to 1.0). + * Returns current allocated / limit. + */ + def getUsageRatio(): Double = { + val alloc = getSingleton + val currentAllocated = alloc.getCurrentAllocated + val totalLimit = alloc.getTotalLimit + if (totalLimit == Long.MaxValue) { + 0.0 // Unlimited, consider as 0% used + } else { + currentAllocated.toDouble / totalLimit.toDouble + } + } + + /** + * Check if host memory usage is below the given threshold (0.0 to 1.0). + */ + def isUsageBelowThreshold(threshold: Double): Boolean = { + getUsageRatio() < threshold + } + def addEventHandler(buff: HostMemoryBuffer, handler: MemoryBuffer.EventHandler): HostMemoryBuffer = { buff.synchronized { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 1cfdb1643cd..4f9320ee899 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -561,6 +561,50 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .bytesConf(ByteUnit.BYTE) .createWithDefault(-1) + val PARTIAL_FILE_BUFFER_INITIAL_SIZE = + conf("spark.rapids.memory.host.partialFileBufferInitialSize") + .doc("The initial size in bytes for a host memory buffer used by " + + "SpillablePartialFileHandle during shuffle write. This buffer allows shuffle " + + "data to be kept in memory instead of writing to disk immediately, reducing " + + "I/O overhead. The buffer can expand dynamically up to partialFileBufferMaxSize. " + + "A smaller initial size reduces upfront memory allocation but may require more " + + "expansions.") + .startupOnly() + .internal() + .bytesConf(ByteUnit.BYTE) + .createWithDefault(128L * 1024 * 1024) // 128MB default + + val PARTIAL_FILE_BUFFER_MAX_SIZE = + conf("spark.rapids.memory.host.partialFileBufferMaxSize") + .doc("The maximum size in bytes for a single host memory buffer used by " + + "SpillablePartialFileHandle during shuffle write. When a buffer needs to " + + "expand beyond this limit, it will be spilled to disk instead. This prevents " + + "excessive memory usage for large shuffle partitions. Note: Due to ByteBuffer " + + "constraints, the effective maximum is Int.MaxValue (~2GB).") + .startupOnly() + .internal() + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Int.MaxValue.toLong) // ~2GB, limited by ByteBuffer + + val PARTIAL_FILE_BUFFER_MEMORY_THRESHOLD = + conf("spark.rapids.memory.host.partialFileBufferMemoryThreshold") + .doc("The host memory usage threshold (as a fraction from 0.0 to 1.0) for deciding " + + "whether to use memory-based buffering for partial files during shuffle write. " + + "When host memory usage exceeds this threshold, file-based storage will be used " + + "directly. This threshold also applies when expanding buffers dynamically. " + + "Setting this too high may cause threads holding the GPU semaphore to block on " + + "spilling, which wastes valuable GPU resources. Setting it too low reduces the " + + "shuffle write optimization benefit. A value around 0.5-0.6 typically provides " + + "optimal performance. As a guideline, ensure that (1 - threshold) * total_host_mem " + + "is greater than num_threads * gpu_batch_size to leave enough memory for other " + + "threads to operate without forcing spills.") + .startupOnly() + .internal() + .doubleConf + .checkValue(v => v > 0.0 && v <= 1.0, + "The memory threshold must be in the range (0.0, 1.0]") + .createWithDefault(0.5) + val UNSPILL = conf("spark.rapids.memory.gpu.unspill.enabled") .doc("When a spilled GPU buffer is needed again, should it be unspilled, or only copied " + "back into GPU memory temporarily. Unspilling may be useful for GPU buffers that are " + @@ -3317,6 +3361,12 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val hostSpillStorageSize: Long = get(HOST_SPILL_STORAGE_SIZE) + lazy val partialFileBufferInitialSize: Long = get(PARTIAL_FILE_BUFFER_INITIAL_SIZE) + + lazy val partialFileBufferMaxSize: Long = get(PARTIAL_FILE_BUFFER_MAX_SIZE) + + lazy val partialFileBufferMemoryThreshold: Double = get(PARTIAL_FILE_BUFFER_MEMORY_THRESHOLD) + lazy val isUnspillEnabled: Boolean = get(UNSPILL) lazy val needDecimalGuarantees: Boolean = get(NEED_DECIMAL_OVERFLOW_GUARANTEES) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 8ef16a471d4..4fee08ffcb9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -1893,7 +1893,7 @@ object SpillFramework extends Logging { // public for tests. Some tests not in the `spill` package require setting this // because they need fine control over allocations. var storesInternal: SpillableStores = _ - + def stores: SpillableStores = { if (storesInternal == null) { throw new IllegalStateException( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala new file mode 100644 index 00000000000..d35fdcb1205 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandle.scala @@ -0,0 +1,647 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.spill + +import java.io.{BufferedInputStream, BufferedOutputStream, File, FileInputStream, FileOutputStream, IOException} + +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.HostAlloc + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.rapids.GpuTaskMetrics + +/** + * Storage mode for SpillablePartialFileHandle. + */ +object PartialFileStorageMode extends Enumeration { + val FILE_ONLY, MEMORY_WITH_SPILL = Value +} + +/** + * A specialized spillable handle for partial files that provides unified write/read + * interfaces for both file-based and memory-based (with spill support) storage. + * + * This handle is designed for scenarios where: + * 1. When memory is scarce (usage > threshold), use file-based storage directly + * 2. When memory is sufficient, use host memory buffer with automatic spill support + * + * Features: + * - Unified write/read interface regardless of storage mode + * - Protection from spill during write phase + * - Sequential read support to avoid frequent stream open/close + * - Automatic transition from memory to disk when spilled + * - Dynamic buffer expansion when capacity is exceeded (up to configured max limit) + * - Automatic fallback to file when expansion fails or conditions not met + * + * @param storageMode Whether to use FILE_ONLY or MEMORY_WITH_SPILL + * @param file File to use for FILE_ONLY mode or as spill target for MEMORY_WITH_SPILL + * @param initialCapacity Initial capacity for buffer allocation (MEMORY_WITH_SPILL only) + * @param maxBufferSize Maximum buffer size before spilling to disk + * @param memoryThreshold Host memory usage threshold for buffer expansion decisions + * @param priority Spill priority for memory-based mode + * @param syncWrites Whether to force outstanding writes to disk + */ +class SpillablePartialFileHandle private ( + storageMode: PartialFileStorageMode.Value, + file: File, + initialCapacity: Long, + maxBufferSize: Long, + memoryThreshold: Double, + priority: Long, + syncWrites: Boolean) + extends HostSpillableHandle[ai.rapids.cudf.HostMemoryBuffer] with Logging { + + // State management + @volatile private var spilledToDisk: Boolean = false + override private[spill] var host: Option[ai.rapids.cudf.HostMemoryBuffer] = None + override val approxSizeInBytes: Long = initialCapacity + + // Track current buffer capacity (can grow via expansion) + private var currentBufferCapacity: Long = initialCapacity + + // Protect from spill during write phase + private var protectedFromSpill: Boolean = true + private var writeFinished: Boolean = false + + // Write state + private var writePosition: Long = 0L + private var fileOutputStream: Option[FileOutputStream] = None + private var bufferedOutputStream: Option[BufferedOutputStream] = None + + // Read state + private var readPosition: Long = 0L + private var fileInputStream: Option[FileInputStream] = None + private var bufferedInputStream: Option[BufferedInputStream] = None + private var totalBytesWritten: Long = 0L + + // Initialize host buffer for MEMORY_WITH_SPILL mode + if (storageMode == PartialFileStorageMode.MEMORY_WITH_SPILL) { + try { + val buffer = ai.rapids.cudf.HostMemoryBuffer.allocate(initialCapacity, false) + host = Some(buffer) + currentBufferCapacity = initialCapacity + this.taskPriority = priority + SpillFramework.stores.hostStore.trackNoSpill(this) + } catch { + case e: Exception => + logWarning(s"Failed to allocate initial buffer of $initialCapacity bytes, " + + s"falling back to file-based storage", e) + // Fallback to file-based if allocation fails + spilledToDisk = true + currentBufferCapacity = 0L + } + } + + /** + * Check if we should use file for IO (either FILE_ONLY mode or spilled). + */ + private def shouldUseFile: Boolean = { + storageMode == PartialFileStorageMode.FILE_ONLY || spilledToDisk + } + + /** + * Expand host buffer capacity to meet required capacity. + * Tries to double the capacity until reaching required size, + * falls back to file-based if expansion fails. + * + * Conditions checked before expansion: + * 1. New capacity does not exceed configured max buffer size limit + * 2. Current memory usage is below configured threshold + * + * @param requiredCapacity The minimum capacity needed + * @return true if successfully expanded, false if spilled to file instead + */ + private def expandBuffer(requiredCapacity: Long): Boolean = { + host match { + case Some(currentBuffer) => + val oldCapacity = currentBufferCapacity + + // Calculate new capacity: keep doubling until >= requiredCapacity + var newCapacity = oldCapacity + while (newCapacity < requiredCapacity && newCapacity < maxBufferSize) { + newCapacity = math.min(newCapacity * 2, maxBufferSize) + } + + // Check if new capacity is still insufficient after expansion + if (newCapacity < requiredCapacity) { + logDebug(s"Buffer expansion cannot meet required capacity " + + s"(need $requiredCapacity bytes, max limit is $maxBufferSize bytes), " + + s"spilling to disk") + spillBufferToFileAndSwitch(currentBuffer) + return false + } + + // Check if new capacity exceeds limit (should not happen due to math.min) + if (newCapacity > maxBufferSize) { + logDebug(s"Buffer expansion would exceed configured limit " + + s"(need $newCapacity bytes, limit is $maxBufferSize bytes), spilling to disk") + spillBufferToFileAndSwitch(currentBuffer) + return false + } + + // Check for Int.MaxValue limit due to ByteBuffer constraints + if (newCapacity > Int.MaxValue) { + logDebug(s"Buffer expansion would exceed ByteBuffer limit (Int.MaxValue) " + + s"required by buffer.asByteBuffer() used during spill " + + s"(need $newCapacity bytes, limit is ${Int.MaxValue} bytes), spilling to disk") + spillBufferToFileAndSwitch(currentBuffer) + return false + } + + // Check if memory usage is still below threshold + if (!HostAlloc.isUsageBelowThreshold(memoryThreshold)) { + logDebug(s"Memory usage above ${memoryThreshold * 100}% threshold, " + + s"spilling to disk instead of expanding") + spillBufferToFileAndSwitch(currentBuffer) + return false + } + + try { + // Allocate new larger buffer + val newBuffer = ai.rapids.cudf.HostMemoryBuffer.allocate(newCapacity, false) + closeOnExcept(newBuffer) { _ => + // Copy existing data + newBuffer.copyFromHostBuffer(0, currentBuffer, 0, writePosition) + + // Remove old buffer tracking and track new one + SpillFramework.removeFromHostStore(this) + currentBuffer.close() + host = Some(newBuffer) + currentBufferCapacity = newCapacity + SpillFramework.stores.hostStore.trackNoSpill(this) + + logDebug(s"Expanded buffer from $oldCapacity to $newCapacity bytes " + + s"(required $requiredCapacity bytes)") + } + true + } catch { + case e: Exception => + logDebug(s"Failed to allocate buffer of $newCapacity bytes, " + + s"spilling to disk", e) + spillBufferToFileAndSwitch(currentBuffer) + false + } + case None => + throw new IllegalStateException("Host buffer is null") + } + } + + /** + * Spill current buffer content to file and switch to file-based mode. + * Called when buffer expansion fails or capacity cannot grow further. + */ + private def spillBufferToFileAndSwitch( + buffer: ai.rapids.cudf.HostMemoryBuffer): Unit = { + // Defensive check: writePosition should not exceed Int.MaxValue + // because expandBuffer() limits buffer size to Int.MaxValue + require(writePosition <= Int.MaxValue, + s"Cannot spill buffer larger than Int.MaxValue: $writePosition bytes") + + // Write current buffer content to file + withResource(new FileOutputStream(file)) { fos => + val channel = fos.getChannel + val bb = buffer.asByteBuffer() + bb.limit(writePosition.toInt) + while (bb.hasRemaining) { + channel.write(bb) + } + if (syncWrites) { + channel.force(true) + } + } + + // Release buffer and switch to file mode + SpillFramework.removeFromHostStore(this) + buffer.close() + host = None + spilledToDisk = true + + logDebug(s"Spilled buffer to ${file.getAbsolutePath} during write " + + s"($writePosition bytes), continuing write to file") + } + + /** + * Write a single byte to the partial file. + * No synchronization needed: write phase is protected from spilling. + */ + def write(b: Int): Unit = { + if (writeFinished) { + throw new IllegalStateException("Write phase already finished") + } + + if (shouldUseFile) { + // FILE_ONLY mode or spilled: write to file + ensureFileOutputStreamOpen() + bufferedOutputStream.get.write(b) + writePosition += 1 + } else { + // MEMORY_WITH_SPILL mode: write to buffer (protected from spill) + host match { + case Some(_) => + // Check if buffer needs expansion + val requiredCapacity = writePosition + 1 + if (requiredCapacity > currentBufferCapacity) { + val expanded = expandBuffer(requiredCapacity) + // After expansion, may have spilled to file, recursively call write + if (!expanded) { + // Spilled to file, retry write (will go to file branch) + write(b) + return + } + } + // Write to buffer (may be new buffer after expansion) + host.get.setByte(writePosition, b.toByte) + writePosition += 1 + case None => + throw new IllegalStateException("Host buffer is null") + } + } + } + + /** + * Write bytes to the partial file. + * No synchronization needed: write phase is protected from spilling. + */ + def write(bytes: Array[Byte], offset: Int, length: Int): Unit = { + if (writeFinished) { + throw new IllegalStateException("Write phase already finished") + } + + if (shouldUseFile) { + // FILE_ONLY mode or spilled: write to file + ensureFileOutputStreamOpen() + bufferedOutputStream.get.write(bytes, offset, length) + writePosition += length + } else { + // MEMORY_WITH_SPILL mode: write to buffer (protected from spill) + host match { + case Some(_) => + // Check if buffer needs expansion + val requiredCapacity = writePosition + length + if (requiredCapacity > currentBufferCapacity) { + logDebug(s"Buffer expansion needed: writePos=$writePosition, length=$length, " + + s"required=$requiredCapacity, current=$currentBufferCapacity") + val expanded = expandBuffer(requiredCapacity) + // After expansion, may have spilled to file, recursively call write + if (!expanded) { + // Spilled to file, retry write (will go to file branch) + write(bytes, offset, length) + return + } + logDebug(s"After expansion: currentCapacity=$currentBufferCapacity, " + + s"bufferLength=${host.get.getLength}") + } + // Write to buffer (may be new buffer after expansion) + host.get.setBytes(writePosition, bytes, offset, length) + writePosition += length + case None => + throw new IllegalStateException("Host buffer is null") + } + } + } + + /** + * Finish write phase and enable spilling. + * After this call, no more writes are allowed but reads can proceed. + * + * This is where we record disk write savings metric: if this handle is in + * MEMORY_WITH_SPILL mode and hasn't spilled yet, it means we successfully + * avoided disk writes during the write phase. + */ + def finishWrite(): Unit = { + // Extract streams under lock, close them outside + val (bos, fos, shouldRecordSavings) = synchronized { + if (writeFinished) { + return + } + + writeFinished = true + totalBytesWritten = writePosition + protectedFromSpill = false + + // Check if we should record disk write savings: + // 1. Must be MEMORY_WITH_SPILL mode (not FILE_ONLY) + // 2. Must not have spilled yet (data still in memory) + // This means we successfully avoided disk writes during write phase + val recordSavings = storageMode == PartialFileStorageMode.MEMORY_WITH_SPILL && + !spilledToDisk && totalBytesWritten > 0 + + val b = bufferedOutputStream + val f = fileOutputStream + bufferedOutputStream = None + fileOutputStream = None + (b, f, recordSavings) + } + + // Close streams outside lock (IO operations can be slow) + bos.foreach { s => + s.flush() + s.close() + } + fos.foreach(_.close()) + + // Record disk write savings if applicable + if (shouldRecordSavings) { + SpillablePartialFileHandle.recordDiskWriteSaved(totalBytesWritten) + logDebug(s"Recorded disk write savings: $totalBytesWritten bytes " + + s"(kept in memory during write phase)") + } + } + + /** + * Read bytes from the partial file sequentially. + * Returns number of bytes actually read, or -1 if EOF. + * + * Note: This method is NOT thread-safe. Concurrent reads from multiple threads + * are not supported. This class is designed for single-threaded sequential reads + * in the shuffle merge phase (see RapidsShuffleInternalManagerBase.mergePartialFiles). + * + * Internal synchronization only protects against concurrent spill operations, + * not concurrent read operations. + */ + def read(bytes: Array[Byte], offset: Int, length: Int): Int = { + if (!writeFinished) { + throw new IllegalStateException("Cannot read before write is finished") + } + + if (readPosition >= totalBytesWritten) { + return -1 // EOF + } + + val actualLength = math.min(length, (totalBytesWritten - readPosition).toInt) + + def readFromFile(bytes: Array[Byte], offset: Int, length: Int): Int = { + ensureFileInputStreamOpen() + val bytesRead = bufferedInputStream.get.read(bytes, offset, length) + if (bytesRead > 0) { + readPosition += bytesRead + } + bytesRead + } + + if (shouldUseFile) { + // File-based: no spill can happen, no synchronization needed + readFromFile(bytes, offset, actualLength) + } else { + // Memory-based: check volatile flag without lock + if (spilledToDisk) { + // Spilled after our first check, no lock needed now + readFromFile(bytes, offset, actualLength) + } else { + // Still in memory, need lock for the read operation + synchronized { + // Double-check: may have spilled between our checks + if (spilledToDisk) { + // Just spilled, release lock and read from file + // Note: We exit synchronized block here + } else { + // Confirmed still in memory, read with lock held + host match { + case Some(buffer) => + buffer.getBytes(bytes, offset, readPosition, actualLength) + readPosition += actualLength + return actualLength + case None => + throw new IllegalStateException("Host buffer is null") + } + } + } + // If we reach here, it means spilled during double-check + readFromFile(bytes, offset, actualLength) + } + } + } + + /** + * Get total bytes written to this partial file. + */ + def getTotalBytesWritten: Long = totalBytesWritten + + /** + * Check if this handle is using MEMORY_WITH_SPILL mode. + */ + def isMemoryBased: Boolean = storageMode == PartialFileStorageMode.MEMORY_WITH_SPILL + + /** + * Check if memory-based data has been spilled to disk. + * Always returns false for FILE_ONLY mode. + */ + def isSpilled: Boolean = spilledToDisk + + /** + * Override spillable to add write phase protection and actual state checks. + * Since approxSizeInBytes is now a fixed val, we need to check actual state here. + */ + override private[spill] def spillable: Boolean = synchronized { + super.spillable && !protectedFromSpill && !spilledToDisk && host.nonEmpty + } + + /** + * Spill memory buffer to disk. + */ + override def spill(): Long = synchronized { + if (storageMode != PartialFileStorageMode.MEMORY_WITH_SPILL) { + return 0L // Nothing to spill for FILE_ONLY mode + } + + if (!writeFinished) { + // This should not happen because protectedFromSpill prevents spill during write + logWarning("Attempted to spill during write phase, which should be protected") + return 0L + } + + host match { + case Some(buffer) => + // Defensive check: totalBytesWritten should not exceed Int.MaxValue + // because expandBuffer() limits buffer size to Int.MaxValue + require(totalBytesWritten <= Int.MaxValue, + s"Cannot spill buffer larger than Int.MaxValue: $totalBytesWritten bytes") + + // Spill all written data to file + val fos = new FileOutputStream(file) + try { + val channel = fos.getChannel + val bb = buffer.asByteBuffer() + bb.limit(totalBytesWritten.toInt) + while (bb.hasRemaining) { + channel.write(bb) + } + if (syncWrites) { + channel.force(true) + } + } finally { + fos.close() + } + + spilledToDisk = true + SpillFramework.removeFromHostStore(this) + buffer.close() + host = None + + logDebug(s"Spilled to ${file.getAbsolutePath} " + + s"($totalBytesWritten bytes)") + + totalBytesWritten + + case None => + 0L // Already spilled + } + } + + /** + * Ensure file output stream is open for writing. + * Thread-safe: uses synchronized to prevent duplicate stream creation. + */ + private def ensureFileOutputStreamOpen(): Unit = synchronized { + if (fileOutputStream.isEmpty) { + val fos = new FileOutputStream(file, true) // append mode + fileOutputStream = Some(fos) + bufferedOutputStream = Some(new BufferedOutputStream(fos, 64 * 1024)) + } + } + + /** + * Ensure file input stream is open for reading. + * Thread-safe: uses synchronized to prevent duplicate stream creation. + */ + private def ensureFileInputStreamOpen(): Unit = synchronized { + if (fileInputStream.isEmpty) { + val fis = new FileInputStream(file) + // Skip to current read position + if (readPosition > 0) { + var remaining = readPosition + while (remaining > 0) { + val skipped = fis.skip(remaining) + if (skipped <= 0) { + throw new IOException(s"Failed to skip to position $readPosition") + } + remaining -= skipped + } + } + fileInputStream = Some(fis) + bufferedInputStream = Some(new BufferedInputStream(fis, 64 * 1024)) + } + } + + /** + * Close and cleanup resources. + */ + override private[spill] def doClose(): Unit = synchronized { + // Close output streams + bufferedOutputStream.foreach { bos => + try { bos.close() } catch { case _: Exception => } + } + bufferedOutputStream = None + fileOutputStream.foreach { fos => + try { fos.close() } catch { case _: Exception => } + } + fileOutputStream = None + + // Close input streams + bufferedInputStream.foreach { bis => + try { bis.close() } catch { case _: Exception => } + } + bufferedInputStream = None + fileInputStream.foreach { fis => + try { fis.close() } catch { case _: Exception => } + } + fileInputStream = None + + // Release host buffer (removes from SpillFramework tracking and closes buffer) + releaseHostResource() + + // Delete file if it exists + if (file != null && file.exists()) { + try { + file.delete() + } catch { + case _: Exception => // Ignore + } + } + } +} + +object SpillablePartialFileHandle extends Logging { + + /** + * Record disk write savings for a SpillablePartialFileHandle. + * Should be called when a handle successfully avoided disk writes during write phase. + * + * This tracks bytes that were kept in memory during shuffle write phase, + * avoiding disk writes compared to the baseline implementation. + * + * @param bytesSaved Number of bytes that avoided disk write + */ + private[spill] def recordDiskWriteSaved(bytesSaved: Long): Unit = { + if (bytesSaved > 0) { + GpuTaskMetrics.get.addDiskWriteSaved(bytesSaved) + logDebug(s"Recorded disk write savings: $bytesSaved bytes " + + s"(kept in memory during write phase)") + } + } + + /** + * Create a file-only handle. + * Data is written directly to disk without using host memory. + * + * @param file File to write data to + * @param syncWrites Whether to force outstanding writes to disk + */ + def createFileOnly(file: File, syncWrites: Boolean = false): + SpillablePartialFileHandle = { + new SpillablePartialFileHandle( + storageMode = PartialFileStorageMode.FILE_ONLY, + file = file, + initialCapacity = 0L, + maxBufferSize = 0L, + memoryThreshold = 0.0, + priority = Long.MinValue, + syncWrites = syncWrites) + } + + /** + * Create a memory-with-spill handle. + * Data is initially written to host memory buffer and can be spilled to disk + * if needed. The buffer will automatically expand when needed (up to + * maxBufferSize limit). + * + * @param initialCapacity Initial size of host memory buffer to allocate + * @param maxBufferSize Maximum buffer size before spilling to disk + * @param memoryThreshold Host memory usage threshold for buffer expansion + * decisions + * @param spillFile File to use when spilling is required + * @param priority Spill priority + * @param syncWrites Whether to force outstanding writes to disk + */ + def createMemoryWithSpill( + initialCapacity: Long, + maxBufferSize: Long, + memoryThreshold: Double, + spillFile: File, + priority: Long = Long.MinValue, + syncWrites: Boolean = false): SpillablePartialFileHandle = { + new SpillablePartialFileHandle( + storageMode = PartialFileStorageMode.MEMORY_WITH_SPILL, + file = spillFile, + initialCapacity = initialCapacity, + maxBufferSize = maxBufferSize, + memoryThreshold = memoryThreshold, + priority = priority, + syncWrites = syncWrites) + } +} + diff --git a/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleDataIO.scala b/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleDataIO.scala new file mode 100644 index 00000000000..89ab9257a9a --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleDataIO.scala @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io + +import org.apache.spark.SparkConf +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents} + +/** + * RAPIDS-optimized implementation of ShuffleDataIO that uses host memory buffers + * for partial sorted files when possible, with automatic spill to disk support. + */ +class RapidsLocalDiskShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { + + override def executor(): ShuffleExecutorComponents = { + new RapidsLocalDiskShuffleExecutorComponents(sparkConf) + } + + override def driver(): ShuffleDriverComponents = { + new LocalDiskShuffleDriverComponents() + } +} + diff --git a/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleExecutorComponents.scala b/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleExecutorComponents.scala new file mode 100644 index 00000000000..e923d57ab83 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleExecutorComponents.scala @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io + +import java.util.{Map => JMap, Optional} + +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.api.{ShuffleExecutorComponents, ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} +import org.apache.spark.storage.BlockManager + +/** + * RAPIDS-optimized executor components that creates RapidsLocalDiskShuffleMapOutputWriter + * instances with host memory buffer support. + */ +class RapidsLocalDiskShuffleExecutorComponents(sparkConf: SparkConf) + extends ShuffleExecutorComponents with Logging { + + private var blockManager: BlockManager = null + private var blockResolver: IndexShuffleBlockResolver = null + + @VisibleForTesting + def this( + sparkConf: SparkConf, + blockManager: BlockManager, + blockResolver: IndexShuffleBlockResolver) = { + this(sparkConf) + this.blockManager = blockManager + this.blockResolver = blockResolver + } + + override def initializeExecutor( + appId: String, + execId: String, + extraConfigs: JMap[String, String]): Unit = { + blockManager = SparkEnv.get.blockManager + if (blockManager == null) { + throw new IllegalStateException("No blockManager available from the SparkEnv.") + } + blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager) + } + + override def createMapOutputWriter( + shuffleId: Int, + mapTaskId: Long, + numPartitions: Int): ShuffleMapOutputWriter = { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers.") + } + + new RapidsLocalDiskShuffleMapOutputWriter( + shuffleId, + mapTaskId, + numPartitions, + blockResolver, + sparkConf) + } + + override def createSingleFileMapOutputWriter( + shuffleId: Int, + mapId: Long): Optional[SingleSpillShuffleMapOutputWriter] = { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers.") + } + Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)) + } +} + diff --git a/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleMapOutputWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleMapOutputWriter.scala new file mode 100644 index 00000000000..95a225589d9 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/shuffle/sort/io/RapidsLocalDiskShuffleMapOutputWriter.scala @@ -0,0 +1,304 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io + +import java.io.{File, IOException, OutputStream} +import java.nio.ByteBuffer +import java.nio.channels.WritableByteChannel +import java.util.Optional + +import com.nvidia.spark.rapids.{HostAlloc, RapidsConf} +import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter, WritableByteChannelWrapper} +import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage +import org.apache.spark.util.Utils + +/** + * RAPIDS-optimized ShuffleMapOutputWriter that writes to host memory first + * (if sufficient memory is available), then spills to disk if needed. + * Particularly useful for partial sorted files in multi-batch scenarios. + */ +class RapidsLocalDiskShuffleMapOutputWriter( + shuffleId: Int, + mapId: Long, + numPartitions: Int, + blockResolver: IndexShuffleBlockResolver, + sparkConf: SparkConf) + extends ShuffleMapOutputWriter with Logging { + + private val partitionLengths = new Array[Long](numPartitions) + private var lastPartitionId = -1 + private var currChannelPosition = 0L + private var bytesWrittenToMergedFile = 0L + + private val outputFile = blockResolver.getDataFile(shuffleId, mapId) + private var outputTempFile: File = null + + // RAPIDS configuration + private val rapidsConf = new RapidsConf(sparkConf) + private val initialBufferSize = rapidsConf.partialFileBufferInitialSize + private val maxBufferSize = rapidsConf.partialFileBufferMaxSize + private val memoryThreshold = rapidsConf.partialFileBufferMemoryThreshold + + // Read Spark's shuffle sync configuration to maintain compatibility + private val syncWrites = sparkConf.get("spark.shuffle.sync", "false").toBoolean + + // RAPIDS optimization: use SpillablePartialFileHandle for unified storage + private var partialFileHandle: Option[SpillablePartialFileHandle] = None + private var storageInitAttempted: Boolean = false + private var forceFileOnly: Boolean = false + + /** + * Force this writer to use file-only mode, bypassing memory-based buffering. + * This is useful for scenarios where memory buffering is not beneficial, + * such as final merge operations. + * + * This method must be called before any partition writer is requested. + */ + def setForceFileOnlyMode(): Unit = { + if (storageInitAttempted) { + throw new IllegalStateException( + "Cannot set force file-only mode after storage has been initialized. " + + "Storage was initialized when getPartitionWriter() was called. " + + "Call setForceFileOnlyMode() before requesting any partition writers.") + } + forceFileOnly = true + } + + // Try to initialize storage on first partition write + private def ensureStorageInitialized(): Unit = { + if (!storageInitAttempted) { + storageInitAttempted = true + outputTempFile = Utils.tempFileWith(outputFile) + + // Check if file-only mode is forced + if (forceFileOnly) { + // Force file-only mode (e.g., for final merge operations) + logDebug(s"Using forced file-only mode for shuffle $shuffleId map $mapId") + val handle = SpillablePartialFileHandle.createFileOnly( + outputTempFile, syncWrites) + partialFileHandle = Some(handle) + } else if (HostAlloc.isUsageBelowThreshold(memoryThreshold)) { + // Memory sufficient: use MEMORY_WITH_SPILL mode + try { + val handle = SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = initialBufferSize, + maxBufferSize = maxBufferSize, + memoryThreshold = memoryThreshold, + spillFile = outputTempFile, + priority = Long.MinValue, + syncWrites = syncWrites) + partialFileHandle = Some(handle) + logDebug(s"Using memory-with-spill mode for shuffle $shuffleId map $mapId " + + s"(initial=${initialBufferSize / 1024 / 1024}MB, " + + s"max=${maxBufferSize / 1024 / 1024}MB)") + } catch { + case e: Exception => + logWarning(s"Failed to create memory buffer, " + + s"falling back to file-only", e) + val handle = SpillablePartialFileHandle.createFileOnly( + outputTempFile, syncWrites) + partialFileHandle = Some(handle) + } + } else { + // Memory scarce: use FILE_ONLY mode + logDebug(s"Host memory usage high, using file-only mode for shuffle " + + s"$shuffleId map $mapId") + val handle = SpillablePartialFileHandle.createFileOnly( + outputTempFile, syncWrites) + partialFileHandle = Some(handle) + } + } + } + + override def getPartitionWriter(reducePartitionId: Int): ShufflePartitionWriter = { + if (reducePartitionId <= lastPartitionId) { + throw new IllegalArgumentException( + "Partitions should be requested in increasing order.") + } + lastPartitionId = reducePartitionId + + // Initialize storage on first partition + ensureStorageInitialized() + + // Record current position for partition length calculation + currChannelPosition = partialFileHandle.map(_.getTotalBytesWritten).getOrElse(0L) + new RapidsPartitionWriter(reducePartitionId) + } + + override def commitAllPartitions(checksums: Array[Long]): MapOutputCommitMessage = { + // Finish write phase to enable spilling and finalize data + partialFileHandle.foreach { handle => + handle.finishWrite() + + // If memory-based, not spilled yet, and has data, force spill to create file + // writeMetadataFileAndCommit requires a valid file + // Skip spill for empty shuffle to avoid creating unnecessary empty files + if (handle.isMemoryBased && !handle.isSpilled && handle.getTotalBytesWritten > 0) { + handle.spill() + } + } + + val resolvedTmp = if (outputTempFile != null && outputTempFile.isFile) { + outputTempFile + } else { + null + } + + logDebug(s"Writing shuffle index file for mapId $mapId with length " + + s"${partitionLengths.length}") + blockResolver.writeMetadataFileAndCommit( + shuffleId, mapId, partitionLengths, checksums, resolvedTmp) + + // Close the partial file handle to release any remaining resources + // (e.g., host buffer if spill() was not called due to empty partitions) + partialFileHandle.foreach(_.close()) + partialFileHandle = None + + MapOutputCommitMessage.of(partitionLengths) + } + + override def abort(error: Throwable): Unit = { + partialFileHandle.foreach(_.close()) + partialFileHandle = None + if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) { + logWarning(s"Failed to delete temporary shuffle file at " + + s"${outputTempFile.getAbsolutePath}") + } + } + + /** + * Get the partial file handle for accessing data. + */ + def getPartialFileHandle(): Option[SpillablePartialFileHandle] = partialFileHandle + + /** + * Get partition lengths array directly (for extracting without reflection). + */ + def getPartitionLengths(): Array[Long] = partitionLengths + + /** + * Finish write phase to finalize data (called before extraction). + */ + def finishWritePhase(): Unit = { + partialFileHandle.foreach(_.finishWrite()) + } + + private class RapidsPartitionWriter(partitionId: Int) extends ShufflePartitionWriter { + private var partStream: OutputStream = null + private var partChannel: WritableByteChannelWrapper = null + + override def openStream(): OutputStream = { + if (partStream == null) { + partStream = new PartitionWriterStream(partitionId) + } + partStream + } + + override def openChannelWrapper(): Optional[WritableByteChannelWrapper] = { + if (partChannel == null) { + partChannel = new PartitionWriterChannel(partitionId) + } + Optional.of(partChannel) + } + + override def getNumBytesWritten(): Long = { + if (partChannel != null) { + partChannel.asInstanceOf[PartitionWriterChannel].getCount + } else if (partStream != null) { + partStream.asInstanceOf[PartitionWriterStream].getCount() + } else { + 0L + } + } + } + + // Unified stream writer using SpillablePartialFileHandle + private class PartitionWriterStream(partitionId: Int) extends OutputStream { + private var count = 0L + private var isClosed = false + + def getCount(): Long = count + + override def write(b: Int): Unit = { + verifyNotClosed() + partialFileHandle.foreach(_.write(b)) + count += 1 + } + + override def write(buf: Array[Byte], pos: Int, length: Int): Unit = { + verifyNotClosed() + partialFileHandle.foreach(_.write(buf, pos, length)) + count += length + } + + override def close(): Unit = { + isClosed = true + partitionLengths(partitionId) = count + bytesWrittenToMergedFile += count + } + + private def verifyNotClosed(): Unit = { + if (isClosed) { + throw new IllegalStateException( + "Attempting to write to a closed block output stream.") + } + } + } + + // Unified channel writer using SpillablePartialFileHandle + private class PartitionWriterChannel(partitionId: Int) + extends WritableByteChannelWrapper { + + private val startPosition = currChannelPosition + + def getCount: Long = { + partialFileHandle.map(_.getTotalBytesWritten).getOrElse(0L) - startPosition + } + + override def channel(): WritableByteChannel = new WritableByteChannel { + private var channelOpen = true + + override def write(src: ByteBuffer): Int = { + if (!channelOpen) { + throw new IOException("Channel is closed") + } + val remaining = src.remaining() + val temp = new Array[Byte](remaining) + src.get(temp) + partialFileHandle.foreach(_.write(temp, 0, remaining)) + remaining + } + + override def isOpen: Boolean = channelOpen + + override def close(): Unit = { + channelOpen = false + } + } + + override def close(): Unit = { + partitionLengths(partitionId) = getCount + bytesWrittenToMergedFile += partitionLengths(partitionId) + } + } +} + diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala index 547c2af5b66..e9b20b90785 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala @@ -278,6 +278,9 @@ class GpuTaskMetrics extends Serializable with Logging { private val maxGpuFootprint = new SizeInBytesAccumulator + // Disk write savings from SpillablePartialFileHandle + private val diskWriteSavedBytes = new LongAccumulator + private var maxHostBytesAllocated: Long = 0 private var maxPageableBytesAllocated: Long = 0 private var maxPinnedBytesAllocated: Long = 0 @@ -339,7 +342,8 @@ class GpuTaskMetrics extends Serializable with Logging { "gpuOnGpuTasksWaitingGPUMaxCount" -> onGpuTasksInWaitingQueueMaxCount, "gpuMaxTaskFootprint" -> maxGpuFootprint, "multithreadReaderMaxParallelism" -> multithreadReaderMaxParallelism, - "gpuMaxConcurrentGpuTasks" -> maxConcurrentGpuTasks + "gpuMaxConcurrentGpuTasks" -> maxConcurrentGpuTasks, + "gpuDiskWriteSavedBytes" -> diskWriteSavedBytes ) def register(sc: SparkContext): Unit = { @@ -489,6 +493,10 @@ class GpuTaskMetrics extends Serializable with Logging { def updateMultithreadReaderMaxParallelism(parallelism: Long): Unit = { multithreadReaderMaxParallelism.add(parallelism) } + + def addDiskWriteSaved(bytes: Long): Unit = { + diskWriteSavedBytes.add(bytes) + } } /** diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala index f8d9091634e..c8b8211d785 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala @@ -16,13 +16,13 @@ package org.apache.spark.sql.rapids -import java.io.{File, FileInputStream} -import java.util.Optional -import java.util.concurrent.{Callable, ConcurrentHashMap, ExecutionException, Executors, Future, LinkedBlockingQueue, TimeUnit} -import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} +import java.io.{IOException, OutputStream} +import java.util.concurrent.{Callable, ConcurrentHashMap, CopyOnWriteArrayList, ExecutionException, Executors, Future, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong} +import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ListBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.withResource @@ -31,7 +31,9 @@ import com.nvidia.spark.rapids.RapidsConf import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.jni.kudo.OpenByteArrayOutputStream import com.nvidia.spark.rapids.shuffle.{RapidsShuffleRequestHandler, RapidsShuffleServer, RapidsShuffleTransport} +import com.nvidia.spark.rapids.spill.SpillablePartialFileHandle import org.apache.spark.{InterruptibleIterator, MapOutputTracker, ShuffleDependency, SparkConf, SparkEnv, TaskContext} import org.apache.spark.executor.ShuffleWriteMetrics @@ -41,13 +43,14 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.{ShuffleWriter, _} import org.apache.spark.shuffle.api._ -import org.apache.spark.shuffle.sort.{BypassMergeSortShuffleHandle, SortShuffleManager} +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.io.{RapidsLocalDiskShuffleDataIO, RapidsLocalDiskShuffleMapOutputWriter} import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBase.{METRIC_DATA_READ_SIZE, METRIC_DATA_SIZE, METRIC_SHUFFLE_COMBINE_TIME, METRIC_SHUFFLE_DESERIALIZATION_TIME, METRIC_SHUFFLE_READ_TIME, METRIC_SHUFFLE_SERIALIZATION_TIME, METRIC_SHUFFLE_WRITE_IO_TIME, METRIC_SHUFFLE_WRITE_TIME} +import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBase.{METRIC_DATA_READ_SIZE, METRIC_DATA_SIZE, METRIC_SHUFFLE_DESERIALIZATION_TIME, METRIC_SHUFFLE_READ_TIME, METRIC_THREADED_WRITER_LIMITER_WAIT_TIME, METRIC_THREADED_WRITER_SERIALIZATION_WAIT_TIME} import org.apache.spark.sql.rapids.shims.{GpuShuffleBlockResolver, RapidsShuffleThreadedReader, RapidsShuffleThreadedWriter} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.{RapidsShuffleBlockFetcherIterator, _} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.{ExternalSorter, OpenHashSet} class GpuShuffleHandle[K, V]( @@ -152,12 +155,15 @@ object RapidsShuffleInternalManagerBase extends Logging { // spark.rapids.shuffle.multiThreaded.reader.threads private var numWriterSlots: Int = 0 private var numReaderSlots: Int = 0 + private var numMergerSlots: Int = 0 private lazy val writerSlots = new mutable.HashMap[Int, Slot]() private lazy val readerSlots = new mutable.HashMap[Int, Slot]() + private lazy val mergerSlots = new mutable.HashMap[Int, Slot]() // used by callers to obtain a unique slot private val writerSlotNumber = new AtomicInteger(0) private val readerSlotNumber= new AtomicInteger(0) + private val mergerSlotNumber = new AtomicInteger(0) private var mtShuffleInitialized: Boolean = false @@ -183,6 +189,17 @@ object RapidsShuffleInternalManagerBase extends Logging { readerSlots(slotNum % numReaderSlots).offer(task) } + /** + * Send a task to a specific merger slot. + * @param slotNum the slot to submit to + * @param task a task to execute + * @note there must not be an uncaught exception while calling + * `task`. + */ + def queueMergerTask[T](slotNum: Int, task: Callable[T]): Future[T] = { + mergerSlots(slotNum % numMergerSlots).offer(task) + } + def startThreadPoolIfNeeded( numWriterThreads: Int, numReaderThreads: Int): Unit = synchronized { @@ -190,6 +207,8 @@ object RapidsShuffleInternalManagerBase extends Logging { mtShuffleInitialized = true numWriterSlots = numWriterThreads numReaderSlots = numReaderThreads + // Use same number of merger slots as writer slots + numMergerSlots = numWriterThreads if (writerSlots.isEmpty) { (0 until numWriterSlots).foreach { slotNum => writerSlots.put(slotNum, new Slot(slotNum, "writer")) @@ -200,6 +219,11 @@ object RapidsShuffleInternalManagerBase extends Logging { readerSlots.put(slotNum, new Slot(slotNum, "reader")) } } + if (mergerSlots.isEmpty) { + (0 until numMergerSlots).foreach { slotNum => + mergerSlots.put(slotNum, new Slot(slotNum, "merger")) + } + } } } @@ -210,10 +234,19 @@ object RapidsShuffleInternalManagerBase extends Logging { readerSlots.values.foreach(_.shutdownNow()) readerSlots.clear() + + mergerSlots.values.foreach(_.shutdownNow()) + mergerSlots.clear() + + // Reset slot counters to ensure clean state for next initialization + writerSlotNumber.set(0) + readerSlotNumber.set(0) + mergerSlotNumber.set(0) } def getNextWriterSlot: Int = Math.abs(writerSlotNumber.incrementAndGet()) def getNextReaderSlot: Int = Math.abs(readerSlotNumber.incrementAndGet()) + def getNextMergerSlot: Int = Math.abs(mergerSlotNumber.incrementAndGet()) } trait RapidsShuffleWriterShimHelper { @@ -246,303 +279,693 @@ abstract class RapidsShuffleThreadedWriterBase[K, V]( numWriterThreads: Int) extends RapidsShuffleWriter[K, V] with RapidsShuffleWriterShimHelper { - private val metrics = handle.metrics - private val serializationTimeMetric = - metrics.get(METRIC_SHUFFLE_SERIALIZATION_TIME) - private val shuffleWriteTimeMetric = - metrics.get(METRIC_SHUFFLE_WRITE_TIME) - private val shuffleCombineTimeMetric = - metrics.get(METRIC_SHUFFLE_COMBINE_TIME) - private val ioTimeMetric = - metrics.get(METRIC_SHUFFLE_WRITE_IO_TIME) private val dep: ShuffleDependency[K, V, V] = handle.dependency private val shuffleId = dep.shuffleId private val partitioner = dep.partitioner private val numPartitions = partitioner.numPartitions private val serializer = dep.serializer.newInstance() - private val transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true) private val fileBufferSize = sparkConf.get(config.SHUFFLE_FILE_BUFFER_SIZE).toInt * 1024 private val limiter = new BytesInFlightLimiter(maxBytesInFlight) + private val limiterWaitTimeMetric = + handle.metrics.get(METRIC_THREADED_WRITER_LIMITER_WAIT_TIME) + private val serializationWaitTimeMetric = + handle.metrics.get(METRIC_THREADED_WRITER_SERIALIZATION_WAIT_TIME) private var shuffleWriteRange: NvtxId = NvtxRegistry.THREADED_WRITER_WRITE.push() + // Case class for tracking partial sorted files in multi-batch scenario + private case class PartialFile( + handle: SpillablePartialFileHandle, + partitionLengths: Array[Long], + mapOutputWriter: ShuffleMapOutputWriter) + /** - * Simple wrapper that tracks the time spent iterating the given iterator. + * Encapsulates all state for processing one GPU batch in the multi-batch shuffle write. + * + * In multi-batch mode, each GPU batch gets its own BatchState with independent buffers, + * futures, and a dedicated merger thread. This enables pipeline parallelism where: + * - Main thread: processes records and queues compression tasks (non-blocking) + * - Writer threads: execute compression tasks in parallel + * - Merger thread: waits for completed compressions and writes partitions sequentially + * + * The merger thread writes partitions in order (0, 1, 2, ...) because Spark's + * ShuffleMapOutputWriter requires sequential partition writes. + * + * @param batchId Unique identifier for this batch (for debugging/logging) + * @param mapOutputWriter Shuffle output writer for this batch. When using + * RapidsLocalDiskShuffleMapOutputWriter, data may be buffered + * in memory first (via SpillablePartialFileHandle) and only + * written to disk on spill or commit. For fallback writers, + * data is written directly to temp files. + * @param partitionBuffers Maps partitionId -> compressed data buffer. Compression tasks + * append data here; merger thread reads and writes to disk. + * @param partitionFutures Maps partitionId -> list of compression task futures. + * Each future returns (uncompressedSize, compressedSize). + * One future per record, so multiple futures when partition has + * multiple records. + * @param partitionWrittenBytes Maps partitionId -> bytes already written to output stream. + * Used for incremental writes as compression tasks complete. + * @param partitionProcessedFutures Maps partitionId -> count of futures already processed. + * Merger thread skips already-processed futures. + * @param maxPartitionIdQueued Highest partition ID that main thread has queued tasks for. + * Merger thread uses this to know when a partition is complete: + * if currentPartition < maxPartitionIdQueued, all data for + * currentPartition has been queued. This typically happens at + * batch boundaries when partition ID wraps back to a lower value. + * @param mergerCondition Condition variable for main thread to wake up merger thread + * when new compression tasks are queued or batch is complete. + * @param mergerSlotNum The merger thread pool slot assigned to this batch. + * @param mergerFuture Future representing the merger task, used to wait for completion. */ - private class TimeTrackingIterator(delegate: Iterator[Product2[K, V]]) - extends Iterator[Product2[K, V]] { + private case class BatchState( + batchId: Int, + mapOutputWriter: ShuffleMapOutputWriter, + partitionBuffers: ConcurrentHashMap[Int, OpenByteArrayOutputStream], + partitionFutures: ConcurrentHashMap[Int, + CopyOnWriteArrayList[Future[(Long, Long)]]], + partitionWrittenBytes: ConcurrentHashMap[Int, Long], + partitionProcessedFutures: ConcurrentHashMap[Int, Int], + maxPartitionIdQueued: AtomicInteger, + mergerCondition: Object, + // Flag for classic wait/notify pattern: set to true when new work is available, + // reset to false after merger thread wakes up and checks actual data state. + // This avoids busy-loop polling and provides clear signal for debugging. + hasNewWork: AtomicBoolean, + mergerSlotNum: Int, + mergerFuture: Future[_]) - private var iterateTimeNs: Long = 0L + /** + * Increment the reference count and get the memory size for a value. + * This method handles ColumnarBatch values with SlicedGpuColumnVector or + * SlicedSerializedColumnVector columns. + * + * @param value the value to process (typically a ColumnarBatch) + * @return a tuple of (ColumnarBatch with incremented ref count, memory size) + * @throws IllegalStateException if value is not a ColumnarBatch or contains + * unsupported column types + */ + private def incRefCountAndGetSize(value: Any): (ColumnarBatch, Long) = { + value match { + case columnarBatch: ColumnarBatch => + if (columnarBatch.numCols() > 0) { + columnarBatch.column(0) match { + case _: SlicedGpuColumnVector => + (SlicedGpuColumnVector.incRefCount(columnarBatch), + SlicedGpuColumnVector.getTotalHostMemoryUsed(columnarBatch)) + case _: SlicedSerializedColumnVector => + (SlicedSerializedColumnVector.incRefCount(columnarBatch), + SlicedSerializedColumnVector.getTotalHostMemoryUsed( + columnarBatch)) + case other => + throw new IllegalStateException( + s"Unexpected column type in ColumnarBatch: ${other.getClass.getName}. " + + "Expected SlicedGpuColumnVector or SlicedSerializedColumnVector.") + } + } else { + (columnarBatch, 0L) + } + case other => + throw new IllegalStateException( + s"Unexpected value type: ${if (other == null) "null" else other.getClass.getName}. " + + "Expected ColumnarBatch.") + } + } - override def hasNext: Boolean = { - val start = System.nanoTime() - val ret = delegate.hasNext - iterateTimeNs += System.nanoTime() - start - ret + /** + * Create independent state for processing one GPU batch. + * This allows multiple batches to be processed in pipeline without blocking. + */ + private def createBatchState( + batchId: Int, + writer: ShuffleMapOutputWriter): BatchState = { + + val partitionBuffers = new ConcurrentHashMap[Int, OpenByteArrayOutputStream]() + val partitionFutures = new ConcurrentHashMap[Int, + CopyOnWriteArrayList[Future[(Long, Long)]]]() + val partitionWrittenBytes = new ConcurrentHashMap[Int, Long]() + val partitionProcessedFutures = new ConcurrentHashMap[Int, Int]() + + // Synchronization strategy for maxPartitionIdQueued and mergerCondition: + // + // maxPartitionIdQueued: Tracks the highest partition ID queued by main thread. + // - Main thread: updates via set() after adding futures, synchronized with + // maxPartitionIdQueued to ensure atomic update with futures.add() + // - Merger thread: reads via get() to check if current partition is complete + // (currentPartition < maxPartitionIdQueued means all data for currentPartition + // has been queued) + // + // mergerCondition: Condition variable for merger thread to wait on. + // - Main thread: sets hasNewWork=true and calls notifyAll() after queuing new tasks + // - Merger thread: uses classic flag pattern (while !hasNewWork wait()) to avoid + // busy-loop polling and provide clear debugging signal + val maxPartitionIdQueued = new AtomicInteger(-1) + val mergerCondition = new Object() + val hasNewWork = new AtomicBoolean(false) + + // Assign a merger slot for this batch + val mergerSlotNum = RapidsShuffleInternalManagerBase.getNextMergerSlot + + var unfinishedStream: Option[OutputStream] = None + + // Helper to write the buffer for a single partition. + // Buffer lifecycle: + // - doCleanUp=false: incremental write, buffer stays open for more data from same partition + // - doCleanUp=true: final write, closes buffer and streams (called when partition is complete) + // Normal path: merger calls with doCleanUp=false for each compression task, then calls + // with doCleanUp=true when containsLastForThisPartition=true (all data for partition queued). + // Exception path: buffers are closed in the finally block of writePartitionedGpuBatches. + def writeBufferForSinglePartition( + partitionId: Int, + start: Long, + end: Long, + doCleanUp: Boolean): Unit = { + Option(partitionBuffers.get(partitionId)) match { + case Some(buffer) => + if (unfinishedStream.isEmpty) { + unfinishedStream = Some(writer.getPartitionWriter(partitionId).openStream()) + } + if (end - start > 0) { + unfinishedStream.get.write(buffer.getBuf, start.toInt, (end - start).toInt) + } + if (doCleanUp) { + buffer.close() + partitionBuffers.remove(partitionId) + unfinishedStream.get.close() + unfinishedStream = None + partitionFutures.remove(partitionId) + partitionProcessedFutures.remove(partitionId) + partitionWrittenBytes.remove(partitionId) + } + case None => + throw new IllegalStateException( + s"No buffer found for partition $partitionId in batch $batchId") + } } - override def next(): Product2[K, V] = { - val start = System.nanoTime() - val ret = delegate.next - iterateTimeNs += System.nanoTime() - start - ret + // Merger task for this batch + val mergerTask = new Runnable { + override def run(): Unit = { + var currentPartitionToWrite = 0 + // Check for thread interruption to allow graceful shutdown + while (currentPartitionToWrite < numPartitions && !Thread.currentThread().isInterrupted) { + if (currentPartitionToWrite <= maxPartitionIdQueued.get()) { + var containsLastForThisPartition = false + var futures: CopyOnWriteArrayList[Future[(Long, Long)]] = null + + maxPartitionIdQueued.synchronized { + futures = partitionFutures.get(currentPartitionToWrite) + if (currentPartitionToWrite < maxPartitionIdQueued.get()) { + containsLastForThisPartition = true + } + } + + if (futures != null) { + // Track if any new future was processed in this iteration + var newFutureTouched = false + val processedCount = + partitionProcessedFutures.getOrDefault(currentPartitionToWrite, 0) + // Process only futures that haven't been processed yet + futures.asScala.zipWithIndex.filter(pair => { + pair._2 >= processedCount + }).foreach { future => + newFutureTouched = true + // remainingQuota is the compressedSize that was held after Writer released + // the excess quota (recordSize - compressedSize) + val (remainingQuota, compressedSize) = future._1.get() + + // Write newly compressed data incrementally + val writtenBytes = + partitionWrittenBytes.getOrDefault(currentPartitionToWrite, 0L) + writeBufferForSinglePartition(currentPartitionToWrite, + writtenBytes, + writtenBytes + compressedSize, + doCleanUp = false) + + partitionWrittenBytes.put( + currentPartitionToWrite, writtenBytes + compressedSize) + partitionProcessedFutures.compute(currentPartitionToWrite, + (key, value) => { value + 1 }) + + // Release the remaining quota after data is written to output stream. + // The limiter controls heap memory pressure from deserialization and + // OpenByteArrayOutputStream. Once data is written to output stream (which may + // buffer in SpillablePartialFileHandle or write to disk), the heap buffer + // becomes eligible for GC. Releasing quota here allows more compression + // tasks to proceed without causing heap OOM. + limiter.release(remainingQuota) + } + + if (containsLastForThisPartition) { + writeBufferForSinglePartition(currentPartitionToWrite, 0, 0, doCleanUp = true) + currentPartitionToWrite += 1 + } else { + if (!newFutureTouched) { + // No new futures were processed in this iteration, wait for main thread + // to queue more compression tasks. Use classic condition flag pattern + // to avoid busy-loop polling and provide clear debugging signal. + mergerCondition.synchronized { + while (!hasNewWork.get()) { + mergerCondition.wait() + } + hasNewWork.set(false) + } + } + } + } else { + val partWriter = writer.getPartitionWriter(currentPartitionToWrite) + partWriter.openStream().close() + currentPartitionToWrite += 1 + } + } else { + // Current partition hasn't been queued yet by main thread, wait for it. + // Use classic condition flag pattern to avoid busy-loop polling and + // provide clear debugging signal. + mergerCondition.synchronized { + while (!hasNewWork.get()) { + mergerCondition.wait() + } + hasNewWork.set(false) + } + } + } + } } - def getIterateTimeNs: Long = iterateTimeNs - } + val mergerFuture = RapidsShuffleInternalManagerBase.queueMergerTask( + mergerSlotNum, () => { + mergerTask.run() + null + }) - override def write(records: Iterator[Product2[K, V]]): Unit = { - // Iterating the `records` may involve some heavy computations. - // TimeTrackingIterator is used to track how much time we spend for such computations. - write(new TimeTrackingIterator(records)) + BatchState( + batchId, + writer, + partitionBuffers, + partitionFutures, + partitionWrittenBytes, + partitionProcessedFutures, + maxPartitionIdQueued, + mergerCondition, + hasNewWork, + mergerSlotNum, + mergerFuture) } - private def write(records: TimeTrackingIterator): Unit = { - // Timestamp when the main processing begins - val processingStart: Long = System.nanoTime() + override def write(records: Iterator[Product2[K, V]]): Unit = { val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( shuffleId, mapId, numPartitions) + mapOutputWriters += mapOutputWriter // Track for cleanup + + val partLengths = if (!records.hasNext) { + commitAllPartitions(mapOutputWriter, true) + } else { + writePartitionedGpuBatches(records, mapOutputWriter) + } + + myMapStatus = Some(getMapStatus(blockManager.shuffleServerId, partLengths, mapId)) + + if (shuffleWriteRange != null) { + shuffleWriteRange.pop() + shuffleWriteRange = null + } + } + + /** + * Unified write path that handles both single batch and multi-batch tasks. + * Uses streaming parallel processing with pipelined partition writing. + * + * Data flow for each record: + * 1. ColumnarBatch (already copied to host memory, may be split from GPU batches based on + * spark.rapids.shuffle.partitioning.maxCpuBatchSize) -> Main thread acquires limiter quota + * 2. Writer thread: serialize + compress -> OpenByteArrayOutputStream (JVM heap) + * 3. Writer thread: release excess quota (recordSize - compressedSize) + * 4. Merger thread: heap buffer -> ShuffleMapOutputWriter (via SpillablePartialFileHandle) + * - If MEMORY_WITH_SPILL mode: data may stay in host memory until spill/commit + * - If FILE_ONLY mode or spilled: data goes to disk + * 5. Merger thread: release remaining quota after writing to output stream + * 6. (Multi-batch only) Main thread: mergePartialFiles() combines all batch outputs into + * final shuffle file, reading from each SpillablePartialFileHandle sequentially + * + * Threading model (same for both scenarios): + * - Main thread: Processes all records without blocking, queues compression tasks + * - Background merger thread(s): Wait for compression tasks to complete and write + * partitions to disk in order + * - Worker threads: Execute compression tasks in parallel + * + * Single batch: One merger thread writes directly to final output file + * + * Multi-batch: Detects partition ID decreasing (indicates new batch), creates + * independent state for each batch (each with its own merger thread running in parallel), + * then merges all batch outputs into final file. + */ + private def writePartitionedGpuBatches( + records: Iterator[Product2[Any, Any]], + mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = { + + val serializerInstance = serializer + var recordsWritten: Long = 0L + + // Track timing for metrics + val writeStartTime = System.nanoTime() + // Track total written size (compressed size) + val totalCompressedSize = new AtomicLong(0L) + var waitTimeOnLimiterNs: Long = 0L + + // Multi-batch tracking + val batchStates = new ArrayBuffer[BatchState]() + val partialFiles = new ArrayBuffer[PartialFile]() + var currentBatchId: Int = 0 + var previousMaxPartition: Int = -1 + var isMultiBatch: Boolean = false + + // Maps partitionId -> writer slot number. Ensures all compression tasks for the same + // partition run serially in the same single-threaded slot, preventing concurrent writes + // to the same partition buffer. Different partitions can still run in parallel. + val partitionSlots = new ConcurrentHashMap[Int, Int]() + + // Create initial batch state + var currentBatch = createBatchState(currentBatchId, mapOutputWriter) + try { - var openTimeNs = 0L - val partLengths = if (!records.hasNext) { - commitAllPartitions(mapOutputWriter, true /*empty checksum*/) - } else { - // per reduce partition id - // open all the writers ahead of time (Spark does this already) - val openStartTime = System.nanoTime() - (0 until numPartitions).map { i => - val (blockId, file) = blockManager.diskBlockManager.createTempShuffleBlock() - val writer: DiskBlockObjectWriter = blockManager.getDiskWriter( - blockId, file, serializer, fileBufferSize, writeMetrics) - setChecksumIfNeeded(writer, i) // spark3.2.0+ - - // Places writer objects at round robin slot numbers apriori - // this choice is for simplicity but likely needs to change so that - // we can handle skew better - val slotNum = RapidsShuffleInternalManagerBase.getNextWriterSlot - diskBlockObjectWriters.put(i, (slotNum, writer)) + while (records.hasNext) { + val record = records.next() + val key = record._1 + val value = record._2 + val reducePartitionId: Int = partitioner.getPartition(key) + + // Detect multi-batch: partition ID must be strictly increasing within a batch. + // If current partition ID < previous max, it means we've jumped back to an earlier + // partition, indicating a new upstream GPU batch. Note: we use < instead of <= because + // consecutive identical partition IDs can occur in two scenarios: + // 1. Reslicing: when a partition's data exceeds maxCpuBatchSize + // 2. Data skew: multiple GPU batches each containing only the same partition's data + // In both cases, merging them into a single shuffle batch is correct and more efficient + // (fewer partial files, less merge overhead). + if (reducePartitionId < previousMaxPartition) { + if (!isMultiBatch) { + isMultiBatch = true + logDebug(s"Detected multi-batch scenario for shuffle $shuffleId, " + + s"transitioning to pipeline mode") + } + + // Signal current batch is complete by setting maxPartitionIdQueued to numPartitions. + // This tells the merger thread that all partitions (0 to numPartitions-1) have been + // queued, so it can finish writing remaining partitions without waiting. + // We notify the merger thread in case it's waiting for more work. + // Note: We don't block here - the merger runs in parallel while we start next batch. + currentBatch.maxPartitionIdQueued.set(numPartitions) + currentBatch.mergerCondition.synchronized { + currentBatch.hasNewWork.set(true) + currentBatch.mergerCondition.notifyAll() + } + + // Add to list for later finalization + batchStates += currentBatch + + // Immediately create new batch and continue processing (pipeline!) + currentBatchId += 1 + val newWriter = shuffleExecutorComponents.createMapOutputWriter( + shuffleId, + mapId, + numPartitions) + mapOutputWriters += newWriter // Track for cleanup + currentBatch = createBatchState(currentBatchId, newWriter) + + // Reset to -1 for new batch. This ensures the first record of the new batch + // (with any valid partition ID >= 0) won't trigger another batch switch, + // since reducePartitionId > -1 will always be true. + previousMaxPartition = -1 } - openTimeNs = System.nanoTime() - openStartTime - // we call write on every writer for every record in parallel - val writeFutures = new mutable.Queue[Future[Unit]] - // Accumulated record write time as if they were sequential - val recordWriteTime: AtomicLong = new AtomicLong(0L) - // Time spent waiting on the limiter - var waitTimeOnLimiterNs: Long = 0L - // Time spent computing ColumnarBatch sizes - var batchSizeComputeTimeNs: Long = 0L + recordsWritten += 1 + previousMaxPartition = math.max(previousMaxPartition, reducePartitionId) - try { - while (records.hasNext) { - // get the record - val record = records.next() - val key = record._1 - val value = record._2 - val reducePartitionId: Int = partitioner.getPartition(key) - val (slotNum, myWriter) = diskBlockObjectWriters(reducePartitionId) - - if (numWriterThreads == 1) { - val recordWriteTimeStart = System.nanoTime() - myWriter.write(key, value) - recordWriteTime.getAndAdd(System.nanoTime() - recordWriteTimeStart) - } else { - // we close batches actively in the `records` iterator as we get the next batch - // this makes sure it is kept alive while a task is able to handle it. - val sizeComputeStart = System.nanoTime() - val (cb, size) = value match { - case columnarBatch: ColumnarBatch => - if (columnarBatch.numCols() > 0) { - columnarBatch.column(0) match { - case _: SlicedGpuColumnVector => - (SlicedGpuColumnVector.incRefCount(columnarBatch), - SlicedGpuColumnVector.getTotalHostMemoryUsed(columnarBatch)) - case _: SlicedSerializedColumnVector => - (SlicedSerializedColumnVector.incRefCount(columnarBatch), - SlicedSerializedColumnVector.getTotalHostMemoryUsed(columnarBatch)) - case _ => - (null, 0L) - } - } else { - (columnarBatch, 0L) - } - case _ => - (null, 0L) + // Get or create futures queue for this partition in current batch + val futures = currentBatch.partitionFutures.computeIfAbsent(reducePartitionId, + _ => new CopyOnWriteArrayList[Future[(Long, Long)]]()) + + val (cb, recordSize) = incRefCountAndGetSize(value) + + // Acquire limiter and process compression task immediately + val waitOnLimiterStart = System.nanoTime() + limiter.acquireOrBlock(recordSize) + waitTimeOnLimiterNs += System.nanoTime() - waitOnLimiterStart + + // Get or assign a slot number for this partition to ensure + // all tasks for the same partition run serially in the same slot + val slotNum = partitionSlots.computeIfAbsent(reducePartitionId, + _ => RapidsShuffleInternalManagerBase.getNextWriterSlot) + val currentBatchFinalRef = currentBatch + val future = RapidsShuffleInternalManagerBase.queueWriteTask(slotNum, () => { + try { + withResource(cb) { _ => + // Get or create buffer for this partition in current batch + val buffer = currentBatchFinalRef.partitionBuffers.computeIfAbsent( + reducePartitionId, _ => new OpenByteArrayOutputStream()) + val originLength = buffer.getCount + + // Serialize + compress + encryption to memory buffer + val compressedOutputStream = blockManager.serializerManager.wrapStream( + ShuffleBlockId(shuffleId, mapId, reducePartitionId), buffer) + + val serializationStream = serializerInstance.serializeStream( + compressedOutputStream) + withResource(serializationStream) { serializer => + serializer.writeKey(key.asInstanceOf[Any]) + serializer.writeValue(value.asInstanceOf[Any]) } - val waitOnLimiterStart = System.nanoTime() - batchSizeComputeTimeNs += waitOnLimiterStart - sizeComputeStart - limiter.acquireOrBlock(size) - waitTimeOnLimiterNs += System.nanoTime() - waitOnLimiterStart - writeFutures += RapidsShuffleInternalManagerBase.queueWriteTask(slotNum, () => { - withResource(cb) { _ => - try { - val recordWriteTimeStart = System.nanoTime() - myWriter.write(key, value) - recordWriteTime.getAndAdd(System.nanoTime() - recordWriteTimeStart) - } finally { - limiter.release(size) - } - } - }) - } - } - } finally { - // This is in a finally block so that if there is an exception queueing - // futures, that we will have waited for any queued write future before we call - // .abort on the map output writer (we had test failures otherwise) - NvtxRegistry.WAITING_FOR_WRITES { - try { - while (writeFutures.nonEmpty) { - try { - writeFutures.dequeue().get() - } catch { - case ee: ExecutionException => - // this exception is a wrapper for the underlying exception - // i.e. `IOException`. The ShuffleWriter.write interface says - // it can throw these. - throw ee.getCause - } + + // Track total written data size (compressed size) + val compressedSize = (buffer.getCount - originLength).toLong + totalCompressedSize.addAndGet(compressedSize) + + // Release excess quota immediately after compression. + // Data is now in OpenByteArrayOutputStream (heap), only need to hold + // compressedSize quota until Merger writes to disk. + + + // Note: excessQuota can be 0 if compression doesn't reduce size (or expands) + val excessQuota = math.max(0L, recordSize - compressedSize) + if (excessQuota > 0) { + limiter.release(excessQuota) } - } finally { - // cancel all pending futures (only in case of error will we cancel) - writeFutures.foreach(_.cancel(true /*ok to interrupt*/)) + + // Return the quota that Merger should release later + // Total released = excessQuota + remainingQuota should equal recordSize + val remainingQuota = recordSize - excessQuota + (remainingQuota, compressedSize) } + } catch { + case e: Exception => + throw new IOException( + s"Failed compression task for shuffle $shuffleId, map $mapId, " + + s"partition $reducePartitionId", e) } + }) + + currentBatch.maxPartitionIdQueued.synchronized { + futures.add(future) + currentBatch.maxPartitionIdQueued.set( + math.max(currentBatch.maxPartitionIdQueued.get(), reducePartitionId)) } - // writeTimeNs is an approximation of the amount of time we spent in - // DiskBlockObjectWriter.write, which involves serializing records and writing them - // on disk. As we use multiple threads for writing, writeTimeNs is - // estimated by 'the total amount of time it took to finish processing the entire logic - // above' minus 'the amount of time it took to do anything expensive other than the - // serialization and the write. The latter involves computations in upstream execs, - // ColumnarBatch size estimation, and the time blocked on the limiter. - val writeTimeNs = (System.nanoTime() - processingStart) - - records.getIterateTimeNs - batchSizeComputeTimeNs - waitTimeOnLimiterNs - - val combineTimeStart = System.nanoTime() - val pl = writePartitionedData(mapOutputWriter) - val combineTimeNs = System.nanoTime() - combineTimeStart - - // add openTime which is also done by Spark, and we are counting - // in the ioTime later - writeMetrics.incWriteTime(openTimeNs) - - // At this point, Spark has timed the amount of time it took to write - // to disk (the IO, per write). But note that when we look at the - // multi threaded case, this metric is now no longer task-time. - // Users need to look at "rs. shuffle write time" (shuffleWriteTimeMetric), - // which does its own calculation at the task-thread level. - // We use ioTimeNs, however, to get an approximation of serialization time. - val ioTimeNs = - writeMetrics.asInstanceOf[ThreadSafeShuffleWriteMetricsReporter].getWriteTime - - // serializationTime is the time spent compressing/encoding batches that wasn't - // counted in the ioTime - val totalPerRecordWriteTime = recordWriteTime.get() + ioTimeNs - val ioRatio = (ioTimeNs.toDouble/totalPerRecordWriteTime) - val serializationRatio = 1.0 - ioRatio - - // update metrics, note that we expect them to be relative to the task - ioTimeMetric.foreach(_ += (ioRatio * writeTimeNs).toLong) - serializationTimeMetric.foreach(_ += (serializationRatio * writeTimeNs).toLong) - // we add all three here because this metric is meant to show the time - // we are blocked on writes - shuffleWriteTimeMetric.foreach(_ += (writeTimeNs + combineTimeNs)) - shuffleCombineTimeMetric.foreach(_ += combineTimeNs) - pl + // Wake up merger thread to process newly queued compression task. + // This enables pipeline parallelism: main thread continues to next record + // while merger thread processes completed compressions in parallel. + currentBatch.mergerCondition.synchronized { + currentBatch.hasNewWork.set(true) + currentBatch.mergerCondition.notifyAll() + } } - myMapStatus = Some(getMapStatus(blockManager.shuffleServerId, partLengths, mapId)) - } catch { - // taken directly from BypassMergeSortShuffleWriter - case e: Exception => + + // Mark end of last batch by setting maxPartitionIdQueued to numPartitions. + // This signals the merger thread that all partitions have been queued. + // Notify ensures merger wakes up to finish any remaining work. + currentBatch.maxPartitionIdQueued.set(numPartitions) + currentBatch.mergerCondition.synchronized { + currentBatch.hasNewWork.set(true) + currentBatch.mergerCondition.notifyAll() + } + + // Add last batch to list + batchStates += currentBatch + + // Wait for all batches to complete (now they can finish in parallel!) + var totalSerializationWaitTimeNs: Long = 0L + batchStates.foreach { batch => try { - mapOutputWriter.abort(e) + val waitStart = System.nanoTime() + batch.mergerFuture.get() + totalSerializationWaitTimeNs += System.nanoTime() - waitStart } catch { - case e2: Exception => - logError("Failed to abort the writer after failing to write map output.", e2); - e.addSuppressed(e2); + case ee: ExecutionException => throw ee.getCause } - throw e - } - if (shuffleWriteRange != null) { - shuffleWriteRange.pop() - shuffleWriteRange = null - } - } - - def writePartitionedData(mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = { - // after all temporary shuffle writes are done, we need to produce a single - // file (shuffle_[map_id]_0) which is done during this commit phase - NvtxRegistry.COMMIT_SHUFFLE { - // per reduce partition - val segments = (0 until numPartitions).map { - reducePartitionId => - withResource(diskBlockObjectWriters(reducePartitionId)._2) { writer => - val segment = writer.commitAndGet() - (reducePartitionId, segment.file) - } + // CRITICAL: For multi-batch, preserve handle before any commit + // commitAllPartitions() would flush/rename data, so we extract first + if (isMultiBatch) { + val (handle, partLengths) = extractHandleAndLengthsFromWriter( + batch.mapOutputWriter) + partialFiles += PartialFile(handle, partLengths, batch.mapOutputWriter) + } else { + // Single batch: commit normally + commitAllPartitions(batch.mapOutputWriter, true) + } } - val writeStartTime = System.nanoTime() - segments.foreach { case (reducePartitionId, file) => - val partWriter = mapOutputWriter.getPartitionWriter(reducePartitionId) - if (file.exists()) { - if (transferToEnabled) { - val maybeOutputChannel: Optional[WritableByteChannelWrapper] = - partWriter.openChannelWrapper() - if (maybeOutputChannel.isPresent) { - writePartitionedDataWithChannel(file, maybeOutputChannel.get()) - } else { - writePartitionedDataWithStream(file, partWriter) - } - } else { - writePartitionedDataWithStream(file, partWriter) + // Update write metrics + val totalWriteTime = System.nanoTime() - writeStartTime + writeMetrics.incWriteTime(totalWriteTime - waitTimeOnLimiterNs) + writeMetrics.incRecordsWritten(recordsWritten) + writeMetrics.incBytesWritten(totalCompressedSize.get()) + limiterWaitTimeMetric.foreach(_ += waitTimeOnLimiterNs) + serializationWaitTimeMetric.foreach(_ += totalSerializationWaitTimeNs) + + } finally { + // Cleanup all batch states + batchStates.foreach { batch => + // Cancel writer future if still running + batch.mergerFuture.cancel(true) + + // Cancel pending futures + batch.partitionFutures.values().asScala.foreach { futuresQueue => + futuresQueue.asScala.foreach(_.cancel(true)) + futuresQueue.clear() + } + + // Clean up buffers + val iter = batch.partitionBuffers.values().iterator() + while (iter.hasNext()) { + try { + iter.next().close() + } catch { + case e: Exception => + logWarning(s"Failed to close partition buffer during cleanup", e) } - file.delete() } } - writeMetrics.incWriteTime(System.nanoTime() - writeStartTime) } - commitAllPartitions(mapOutputWriter, false /*non-empty checksums*/) - } - // taken from BypassMergeSortShuffleWriter - // this code originally called into guava.Closeables.close - // and had logic to silence exceptions thrown while copying - // I am ignoring this for now. - def writePartitionedDataWithStream(file: java.io.File, writer: ShufflePartitionWriter): Unit = { - withResource(new FileInputStream(file)) { in => - withResource(writer.openStream()) { os => - Utils.copyStream(in, os, false, false) + // Handle final output + if (isMultiBatch) { + // Multi-batch: create NEW writer for final merge + // CRITICAL: Cannot reuse mapOutputWriter as it would write to same outputTempFile + val finalMergeWriter = shuffleExecutorComponents.createMapOutputWriter( + shuffleId, + mapId, + numPartitions) + mapOutputWriters += finalMergeWriter // Track for cleanup + + // Force file-only mode for final merge writer since it doesn't benefit + // from memory buffering (merge operation is already doing sequential I/O) + finalMergeWriter match { + case rapidsWriter: RapidsLocalDiskShuffleMapOutputWriter => + rapidsWriter.setForceFileOnlyMode() + case _ => // Other writer types don't need this optimization } + + mergePartialFiles(partialFiles.toSeq, finalMergeWriter) + } else { + // Single batch: already committed, just return lengths + getPartitionLengths } } - // taken from BypassMergeSortShuffleWriter - // this code originally called into guava.Closeables.close - // and had logic to silence exceptions thrown while copying - // I am ignoring this for now. - def writePartitionedDataWithChannel( - file: File, - outputChannel: WritableByteChannelWrapper): Unit = { - // note outputChannel.close() doesn't actually close it. - // The call is there to record keep the partition lengths - // after the serialization completes. - withResource(outputChannel) { _ => - withResource(new FileInputStream(file)) { in => - withResource(in.getChannel) { inputChannel => - Utils.copyFileStreamNIO( - inputChannel, outputChannel.channel, 0L, inputChannel.size) + /** + * Merge multiple partial sorted files into final output. + * Each partial file contains data for all partitions (0 to N) from one GPU batch. + * The merged file will have: partition 0 from all batches, partition 1 from all batches, etc. + * + * Layout of merged file: + * partition 0 data from partial file 0 + * partition 0 data from partial file 1 + * ... + * partition 0 data from partial file M + * partition 1 data from partial file 0 + * partition 1 data from partial file 1 + * ... + */ + private def mergePartialFiles( + partialFiles: Seq[PartialFile], + finalWriter: ShuffleMapOutputWriter): Array[Long] = { + + val mergeStartTime = System.nanoTime() + + try { + // For each partition, copy data from all partial files in order + // Note: Each partial file is read sequentially from beginning to end, + // so no need to reset read position between partitions + (0 until numPartitions).foreach { partitionId => + val partWriter = finalWriter.getPartitionWriter(partitionId) + + withResource(partWriter.openStream()) { os => + partialFiles.foreach { partialFile => + val partitionLength = partialFile.partitionLengths(partitionId) + if (partitionLength > 0) { + val handle = partialFile.handle + + // Read partition data sequentially + // No reset needed - handle maintains read position automatically + val temp = new Array[Byte](fileBufferSize) + var remaining = partitionLength + while (remaining > 0) { + val bytesToRead = math.min(remaining, temp.length).toInt + val bytesRead = handle.read(temp, 0, bytesToRead) + if (bytesRead > 0) { + os.write(temp, 0, bytesRead) + remaining -= bytesRead + } else { + throw new IOException( + s"EOF reading partition $partitionId " + + s"from partial file ${partialFiles.indexOf(partialFile)}, " + + s"expected $partitionLength bytes, got ${partitionLength - remaining}") + } + } + } + } + } + } + } finally { + // Cleanup partial file handles + partialFiles.foreach { pf => + try { + pf.handle.close() + } catch { + case e: Exception => + logWarning(s"Failed to close partial file handle during cleanup", e) } } } - } + writeMetrics.incWriteTime(System.nanoTime() - mergeStartTime) + // Commit final merged output + commitAllPartitions(finalWriter, true) + } + + /** + * Extract partial file handle and partitionLengths from ShuffleMapOutputWriter. + * Since we always use RapidsLocalDiskShuffleMapOutputWriter, this is straightforward. + */ + private def extractHandleAndLengthsFromWriter(writer: ShuffleMapOutputWriter): + (SpillablePartialFileHandle, Array[Long]) = { + writer match { + case rapidsWriter: RapidsLocalDiskShuffleMapOutputWriter => + // finishWritePhase() will enable spill + rapidsWriter.finishWritePhase() + val handle = rapidsWriter.getPartialFileHandle().getOrElse { + throw new IllegalStateException("RAPIDS writer should have a handle") + } + val lengths = rapidsWriter.getPartitionLengths() + (handle, lengths) + case _ => + throw new IllegalStateException( + s"Unexpected writer type: ${writer.getClass.getName}. " + + "RapidsShuffleManager should always use RapidsLocalDiskShuffleMapOutputWriter.") + } + } def getBytesInFlight: Long = limiter.getBytesInFlight } @@ -750,12 +1173,12 @@ abstract class RapidsShuffleThreadedReaderBase[K, C]( } } futures.clear() - try { + try { if (fallbackIter != null) { fallbackIter.close() } } catch { - case t: Throwable => + case t: Throwable => if (failedFuture.isEmpty) { failedFuture = Some(t) } else { @@ -772,8 +1195,8 @@ abstract class RapidsShuffleThreadedReaderBase[K, C]( if (fallbackIter != null) { fallbackIter.hasNext } else { - pendingIts.nonEmpty || - fetcherIterator.hasNext || futures.nonEmpty || queued.size() > 0 + pendingIts.nonEmpty || futures.nonEmpty || queued.size() > 0 || + fetcherIterator.hasNext } } @@ -1317,8 +1740,20 @@ class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: Boolean) } lazy val execComponents: Option[ShuffleExecutorComponents] = { - import scala.collection.JavaConverters._ - val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() + // Check if user configured a different ShuffleDataIO plugin + val configuredPlugin = conf.get("spark.shuffle.sort.io.plugin.class", "") + val rapidsPlugin = "org.apache.spark.shuffle.sort.io.RapidsLocalDiskShuffleDataIO" + + if (configuredPlugin.nonEmpty && !configuredPlugin.endsWith("RapidsLocalDiskShuffleDataIO")) { + throw new IllegalArgumentException( + s"RapidsShuffleManager requires 'spark.shuffle.sort.io.plugin.class' to be " + + s"'$rapidsPlugin' or unset, but found '$configuredPlugin'. " + + s"Please update your configuration.") + } + + val rapidsDataIO = new RapidsLocalDiskShuffleDataIO(conf) + val executorComponents = rapidsDataIO.executor() + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap executorComponents.initializeExecutor( conf.getAppId, @@ -1355,15 +1790,15 @@ class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: Boolean) getCatalogOrThrow, server, gpu.dependency.metrics) - case bmssh: BypassMergeSortShuffleHandle[_, _] => - bmssh.dependency match { + case handle: BaseShuffleHandle[_, _, _] => + handle.dependency match { case gpuDep: GpuShuffleDependency[_, _, _] if gpuDep.useMultiThreadedShuffle && rapidsConf.shuffleMultiThreadedWriterThreads > 0 => // use the threaded writer if the number of threads specified is 1 or above, // with 0 threads we fallback to the Spark-provided writer. val handleWithMetrics = new ShuffleHandleWithMetrics( - bmssh.shuffleId, + handle.shuffleId, gpuDep.metrics, // cast the handle with specific generic types due to type-erasure gpuDep.asInstanceOf[GpuShuffleDependency[K, V, V]]) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala index a6027dd93cb..5b31776162b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2025, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -290,8 +290,6 @@ object GpuShuffleExchangeExecBase { val METRIC_DESC_DATA_SIZE = "data size" val METRIC_DATA_READ_SIZE = "dataReadSize" val METRIC_DESC_DATA_READ_SIZE = "data read size" - val METRIC_SHUFFLE_SERIALIZATION_TIME = "rapidsShuffleSerializationTime" - val METRIC_DESC_SHUFFLE_SERIALIZATION_TIME = "RAPIDS shuffle serialization time" val METRIC_SHUFFLE_SER_STREAM_TIME = "rapidsShuffleSerializationStreamTime" val METRIC_DESC_SHUFFLE_SER_STREAM_TIME = "RAPIDS shuffle serialization to output stream time" val METRIC_SHUFFLE_DESERIALIZATION_TIME = "rapidsShuffleDeserializationTime" @@ -301,12 +299,6 @@ object GpuShuffleExchangeExecBase { "RAPIDS shuffle deserialization from input stream time" val METRIC_SHUFFLE_PARTITION_TIME = "rapidsShufflePartitionTime" val METRIC_DESC_SHUFFLE_PARTITION_TIME = "RAPIDS shuffle partition time" - val METRIC_SHUFFLE_WRITE_TIME = "rapidsShuffleWriteTime" - val METRIC_DESC_SHUFFLE_WRITE_TIME = "RAPIDS shuffle shuffle write time" - val METRIC_SHUFFLE_COMBINE_TIME = "rapidsShuffleCombineTime" - val METRIC_DESC_SHUFFLE_COMBINE_TIME = "RAPIDS shuffle shuffle combine time" - val METRIC_SHUFFLE_WRITE_IO_TIME = "rapidsShuffleWriteIoTime" - val METRIC_DESC_SHUFFLE_WRITE_IO_TIME = "RAPIDS shuffle shuffle write io time" val METRIC_SHUFFLE_READ_TIME = "rapidsShuffleReadTime" val METRIC_DESC_SHUFFLE_READ_TIME = "RAPIDS shuffle shuffle read time" val METRIC_SHUFFLE_SER_COPY_BUFFER_TIME = "rapidsShuffleSerializationCopyBufferTime" @@ -314,13 +306,18 @@ object GpuShuffleExchangeExecBase { val METRIC_SHUFFLE_STALLED_BY_INPUT_STREAM = "rapidsShuffleStalledByInputStream" val METRIC_DESC_SHUFFLE_STALLED_BY_INPUT_STREAM = "RAPIDS shuffle time stalled by input stream operations" + val METRIC_THREADED_WRITER_LIMITER_WAIT_TIME = "rapidsThreadedWriterLimiterWaitTime" + val METRIC_DESC_THREADED_WRITER_LIMITER_WAIT_TIME = + "threaded writer limiter wait time" + val METRIC_THREADED_WRITER_SERIALIZATION_WAIT_TIME = + "rapidsThreadedWriterSerializationWaitTime" + val METRIC_DESC_THREADED_WRITER_SERIALIZATION_WAIT_TIME = + "threaded writer serialization wait time" def createAdditionalExchangeMetrics(gpu: GpuExec): Map[String, GpuMetric] = Map( // dataSize and dataReadSize are uncompressed, one is on write and the other on read METRIC_DATA_SIZE -> gpu.createSizeMetric(ESSENTIAL_LEVEL, METRIC_DESC_DATA_SIZE), METRIC_DATA_READ_SIZE -> gpu.createSizeMetric(MODERATE_LEVEL, METRIC_DESC_DATA_READ_SIZE), - METRIC_SHUFFLE_SERIALIZATION_TIME -> - gpu.createNanoTimingMetric(DEBUG_LEVEL,METRIC_DESC_SHUFFLE_SERIALIZATION_TIME), METRIC_SHUFFLE_SER_STREAM_TIME -> gpu.createNanoTimingMetric(DEBUG_LEVEL, METRIC_DESC_SHUFFLE_SER_STREAM_TIME), METRIC_SHUFFLE_DESERIALIZATION_TIME -> @@ -329,18 +326,18 @@ object GpuShuffleExchangeExecBase { gpu.createNanoTimingMetric(DEBUG_LEVEL, METRIC_DESC_SHUFFLE_DESER_STREAM_TIME), METRIC_SHUFFLE_PARTITION_TIME -> gpu.createNanoTimingMetric(DEBUG_LEVEL, METRIC_DESC_SHUFFLE_PARTITION_TIME), - METRIC_SHUFFLE_WRITE_TIME -> - gpu.createNanoTimingMetric(ESSENTIAL_LEVEL, METRIC_DESC_SHUFFLE_WRITE_TIME), - METRIC_SHUFFLE_COMBINE_TIME -> - gpu.createNanoTimingMetric(DEBUG_LEVEL, METRIC_DESC_SHUFFLE_COMBINE_TIME), - METRIC_SHUFFLE_WRITE_IO_TIME -> - gpu.createNanoTimingMetric(DEBUG_LEVEL, METRIC_DESC_SHUFFLE_WRITE_IO_TIME), METRIC_SHUFFLE_READ_TIME -> gpu.createNanoTimingMetric(ESSENTIAL_LEVEL, METRIC_DESC_SHUFFLE_READ_TIME), METRIC_SHUFFLE_SER_COPY_BUFFER_TIME -> gpu.createNanoTimingMetric(DEBUG_LEVEL, METRIC_DESC_SHUFFLE_SER_COPY_BUFFER_TIME), METRIC_SHUFFLE_STALLED_BY_INPUT_STREAM -> - gpu.createNanoTimingMetric(DEBUG_LEVEL, METRIC_DESC_SHUFFLE_STALLED_BY_INPUT_STREAM) + gpu.createNanoTimingMetric(DEBUG_LEVEL, METRIC_DESC_SHUFFLE_STALLED_BY_INPUT_STREAM), + METRIC_THREADED_WRITER_LIMITER_WAIT_TIME -> + gpu.createNanoTimingMetric(DEBUG_LEVEL, + METRIC_DESC_THREADED_WRITER_LIMITER_WAIT_TIME), + METRIC_THREADED_WRITER_SERIALIZATION_WAIT_TIME -> + gpu.createNanoTimingMetric(DEBUG_LEVEL, + METRIC_DESC_THREADED_WRITER_SERIALIZATION_WAIT_TIME) ) def prepareBatchShuffleDependency( diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsShuffleWriter.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsShuffleWriter.scala index a5a9303bfae..d357c822ffb 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsShuffleWriter.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsShuffleWriter.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024-2025, NVIDIA CORPORATION. + * Copyright (c) 2024-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids -import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.shuffle.{RapidsShuffleServer, RapidsShuffleTransport} @@ -58,6 +58,7 @@ import com.nvidia.spark.rapids.shuffle.{RapidsShuffleServer, RapidsShuffleTransp import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.ShuffleWriter +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter import org.apache.spark.storage._ @@ -65,7 +66,10 @@ abstract class RapidsShuffleWriter[K, V]() extends ShuffleWriter[K, V] with Logging { protected var myMapStatus: Option[MapStatus] = None - protected val diskBlockObjectWriters = new mutable.HashMap[Int, (Int, DiskBlockObjectWriter)]() + + // Track all ShuffleMapOutputWriters created during write + // Needed for proper cleanup on error or for partial files + protected val mapOutputWriters = new ArrayBuffer[ShuffleMapOutputWriter]() /** * Are we in the process of stopping? Because map tasks can call stop() with success = true * and then call stop() with success = false if they get an exception, we want to make sure @@ -98,17 +102,18 @@ abstract class RapidsShuffleWriter[K, V]() } } } - + private def cleanupTempData(): Unit = { - // The map task failed, so delete our output data. - try { - diskBlockObjectWriters.values.foreach { case (_, writer) => - val file = writer.revertPartialWritesAndClose() - if (!file.delete()) logError(s"Error while deleting file ${file.getAbsolutePath()}") + // Abort all map output writers to clean up temp files + mapOutputWriters.foreach { writer => + try { + writer.abort(null) + } catch { + case e: Exception => + logWarning(s"Failed to abort map output writer: ${e.getMessage}") } - } finally { - diskBlockObjectWriters.clear() } + mapOutputWriters.clear() } } diff --git a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/RapidsShuffleWriter.scala b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/RapidsShuffleWriter.scala index 5f92fc0f26a..c080ff9ae9c 100644 --- a/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/RapidsShuffleWriter.scala +++ b/sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/RapidsShuffleWriter.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024-2025, NVIDIA CORPORATION. + * Copyright (c) 2024-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids -import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.shuffle.{RapidsShuffleServer, RapidsShuffleTransport} @@ -28,13 +28,17 @@ import com.nvidia.spark.rapids.shuffle.{RapidsShuffleServer, RapidsShuffleTransp import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatusWithStats import org.apache.spark.shuffle.ShuffleWriter +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter import org.apache.spark.storage._ abstract class RapidsShuffleWriter[K, V]() extends ShuffleWriter[K, V] with Logging { protected var myMapStatus: Option[MapStatusWithStats] = None - protected val diskBlockObjectWriters = new mutable.HashMap[Int, (Int, DiskBlockObjectWriter)]() + + // Track all ShuffleMapOutputWriters created during write + // Needed for proper cleanup on error or for partial files + protected val mapOutputWriters = new ArrayBuffer[ShuffleMapOutputWriter]() /** * Are we in the process of stopping? Because map tasks can call stop() with success = true * and then call stop() with success = false if they get an exception, we want to make sure @@ -67,17 +71,18 @@ abstract class RapidsShuffleWriter[K, V]() } } } - + private def cleanupTempData(): Unit = { - // The map task failed, so delete our output data. - try { - diskBlockObjectWriters.values.foreach { case (_, writer) => - val file = writer.revertPartialWritesAndClose() - if (!file.delete()) logError(s"Error while deleting file ${file.getAbsolutePath()}") + // Abort all map output writers to clean up temp files + mapOutputWriters.foreach { writer => + try { + writer.abort(null) + } catch { + case e: Exception => + logWarning(s"Failed to abort map output writer: ${e.getMessage}") } - } finally { - diskBlockObjectWriters.clear() } + mapOutputWriters.clear() } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/LogCaptureUtils.scala b/tests/src/test/scala/com/nvidia/spark/rapids/LogCaptureUtils.scala new file mode 100644 index 00000000000..5793f1fdff8 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/LogCaptureUtils.scala @@ -0,0 +1,256 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import java.lang.{Boolean => JBoolean} +import java.lang.reflect.{InvocationHandler, Method, Proxy} + +import scala.collection.mutable.ArrayBuffer + +/** + * Utility for capturing log messages during test execution. + * Automatically detects and works with both Log4j 1.x and Log4j 2.x. + */ +object LogCaptureUtils { + + private val isLog4j2: Boolean = { + try { + Class.forName("org.apache.logging.log4j.core.LoggerContext") + true + } catch { + case _: ClassNotFoundException => false + } + } + + /** + * Capture log messages from specific loggers during operation execution. + * + * @param loggerNames Names of loggers to capture messages from + * @param operation The operation to execute while capturing logs + * @return Array of captured log messages + */ + def captureLogsFrom(loggerNames: Seq[String])(operation: => Unit): Array[String] = { + val logMessages = new ArrayBuffer[String]() + + val capturer = if (isLog4j2) { + new Log4j2Capturer(loggerNames, logMessages) + } else { + new Log4j1Capturer(loggerNames, logMessages) + } + + try { + capturer.setup() + operation + } finally { + capturer.cleanup() + } + + logMessages.toArray + } +} + +/** + * Base trait for log capture implementations. + */ +private trait LogCapturer { + def setup(): Unit + def cleanup(): Unit +} + +/** + * Log capturer implementation for Log4j 1.x using reflection. + */ +private class Log4j1Capturer( + loggerNames: Seq[String], + logMessages: ArrayBuffer[String]) extends LogCapturer { + + private val loggerClass = Class.forName("org.apache.log4j.Logger") + private val levelClass = Class.forName("org.apache.log4j.Level") + private val debugLevel = levelClass.getField("DEBUG").get(null) + + private val getLoggerMethod = loggerClass.getMethod("getLogger", classOf[String]) + private val getRootLoggerMethod = loggerClass.getMethod("getRootLogger") + private val getLevelMethod = loggerClass.getMethod("getLevel") + private val setLevelMethod = loggerClass.getMethod("setLevel", levelClass) + + private val loggers = loggerNames.map(name => + getLoggerMethod.invoke(null, name)) + private val rootLogger = getRootLoggerMethod.invoke(null) + private val origLevels = loggers.map(logger => + getLevelMethod.invoke(logger)) + + private val appenderClass = Class.forName("org.apache.log4j.Appender") + + private val appender = Proxy.newProxyInstance( + getClass.getClassLoader, + Array(appenderClass), + new InvocationHandler { + override def invoke(proxy: Any, method: Method, + args: Array[Object]): Object = { + method.getName match { + case "doAppend" if args != null && args.length > 0 => + val event = args(0) + val getRenderedMessageMethod = event.getClass.getMethod( + "getRenderedMessage") + val message = getRenderedMessageMethod.invoke(event).toString + logMessages.synchronized { logMessages += message } + null + case "getName" => "TestCaptureAppender" + case "close" => null + case "requiresLayout" => JBoolean.FALSE + case "equals" => + if (args != null && args.length == 1) { + JBoolean.valueOf( + proxy.asInstanceOf[AnyRef] eq args(0).asInstanceOf[AnyRef]) + } else { + JBoolean.FALSE + } + case "hashCode" => + Integer.valueOf(System.identityHashCode(proxy)) + case "toString" => "TestCaptureAppender" + case _ => null + } + } + } + ) + + override def setup(): Unit = { + loggers.foreach(logger => setLevelMethod.invoke(logger, debugLevel)) + val addAppenderMethod = rootLogger.getClass.getMethod( + "addAppender", appenderClass) + addAppenderMethod.invoke(rootLogger, appender) + } + + override def cleanup(): Unit = { + val removeAppenderMethod = rootLogger.getClass.getMethod( + "removeAppender", appenderClass) + removeAppenderMethod.invoke(rootLogger, appender) + + loggers.zip(origLevels).foreach { case (logger, origLevel) => + if (origLevel != null) { + setLevelMethod.invoke(logger, origLevel) + } + } + } +} + +/** + * Log capturer implementation for Log4j 2.x using reflection. + */ +private class Log4j2Capturer( + loggerNames: Seq[String], + logMessages: ArrayBuffer[String]) extends LogCapturer { + + private val logManagerClass = Class.forName("org.apache.logging.log4j.LogManager") + private val getContextMethod = logManagerClass.getMethod( + "getContext", classOf[Boolean]) + private val context = getContextMethod.invoke(null, JBoolean.FALSE) + + private val getConfigurationMethod = context.getClass.getMethod("getConfiguration") + private val config = getConfigurationMethod.invoke(context) + + private val levelClass = Class.forName("org.apache.logging.log4j.Level") + private val debugLevel = levelClass.getField("DEBUG").get(null) + + private val appenderClass = Class.forName( + "org.apache.logging.log4j.core.Appender") + + private val appender = Proxy.newProxyInstance( + getClass.getClassLoader, + Array(appenderClass), + new InvocationHandler { + override def invoke(proxy: Any, method: Method, + args: Array[Object]): Object = { + method.getName match { + case "append" if args != null && args.length > 0 => + val logEvent = args(0) + val getMessageMethod = logEvent.getClass.getMethod("getMessage") + val message = getMessageMethod.invoke(logEvent) + val getFormattedMessageMethod = message.getClass.getMethod( + "getFormattedMessage") + val formattedMsg = getFormattedMessageMethod.invoke(message).toString + logMessages.synchronized { logMessages += formattedMsg } + null + case "getName" => "TestCaptureAppender" + case "isStarted" => JBoolean.TRUE + case "isStopped" => JBoolean.FALSE + case "getLayout" => null + case "ignoreExceptions" => JBoolean.TRUE + case "equals" => + if (args != null && args.length == 1) { + JBoolean.valueOf( + proxy.asInstanceOf[AnyRef] eq args(0).asInstanceOf[AnyRef]) + } else { + JBoolean.FALSE + } + case "hashCode" => + Integer.valueOf(System.identityHashCode(proxy)) + case "toString" => "TestCaptureAppender" + case _ => null + } + } + } + ) + + private val getLoggerConfigMethod = config.getClass.getMethod( + "getLoggerConfig", classOf[String]) + + private val origLevels = loggerNames.map { name => + val loggerConfig = getLoggerConfigMethod.invoke(config, name) + val getLevelMethod = loggerConfig.getClass.getMethod("getLevel") + val origLevel = getLevelMethod.invoke(loggerConfig) + (loggerConfig, origLevel) + } + + override def setup(): Unit = { + loggerNames.foreach { name => + val loggerConfig = getLoggerConfigMethod.invoke(config, name) + val setLevelMethod = loggerConfig.getClass.getMethod( + "setLevel", levelClass) + setLevelMethod.invoke(loggerConfig, debugLevel) + + val addAppenderMethod = loggerConfig.getClass.getMethod( + "addAppender", appenderClass, levelClass, + Class.forName("org.apache.logging.log4j.core.Filter")) + addAppenderMethod.invoke(loggerConfig, appender, debugLevel, null) + } + + val updateLoggersMethod = context.getClass.getMethod("updateLoggers") + updateLoggersMethod.invoke(context) + } + + override def cleanup(): Unit = { + loggerNames.foreach { name => + val loggerConfig = getLoggerConfigMethod.invoke(config, name) + val removeAppenderMethod = loggerConfig.getClass.getMethod( + "removeAppender", classOf[String]) + removeAppenderMethod.invoke(loggerConfig, "TestCaptureAppender") + } + + origLevels.foreach { case (loggerConfig, origLevel) => + if (origLevel != null) { + val setLevelMethod = loggerConfig.getClass.getMethod( + "setLevel", levelClass) + setLevelMethod.invoke(loggerConfig, origLevel) + } + } + + val updateLoggersMethod = context.getClass.getMethod("updateLoggers") + updateLoggersMethod.invoke(context) + } +} + diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsShuffleIntegrationSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsShuffleIntegrationSuite.scala new file mode 100644 index 00000000000..d330e7056c0 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsShuffleIntegrationSuite.scala @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession + +/** + * Integration test suite for RAPIDS Shuffle Manager. + * Tests buffer fallback behavior under different configurations: + * - Verifies SpillablePartialFileHandle memory-to-disk fallback + * - Tests different partialFileBufferMaxSize settings + * - Validates forced file-only mode for finalMergeWriter + */ +class RapidsShuffleIntegrationSuite extends AnyFunSuite with BeforeAndAfterEach { + + private var spark: SparkSession = _ + + /** + * Capture log messages during operation and check for buffer behavior indicators. + * + * @param operation The operation to execute while capturing logs + * @return Tuple of (hasExpansion, hasSpill, hasForcedFileOnly) + * - hasExpansion: Whether buffer expansion occurred + * - hasSpill: Whether buffer spilled to disk + * - hasForcedFileOnly: Whether forced file-only mode was used + */ + private def checkLogsForBufferBehavior(operation: => Unit): + (Boolean, Boolean, Boolean) = { + val loggerNames = Seq( + "com.nvidia.spark.rapids.spill.SpillablePartialFileHandle", + "org.apache.spark.shuffle.sort.io.RapidsLocalDiskShuffleMapOutputWriter" + ) + + val logMessages = LogCaptureUtils.captureLogsFrom(loggerNames) { + operation + } + + val hasExpansion = logMessages.exists(_.contains("Expanded buffer from")) + val hasSpill = logMessages.exists(_.contains("Spilled buffer to")) + val hasForcedFileOnly = logMessages.exists(_.contains("Using forced file-only mode")) + + (hasExpansion, hasSpill, hasForcedFileOnly) + } + + /** + * Create SparkSession with custom partialFileBufferMaxSize. + * Required because partialFileBufferMaxSize is .startupOnly() config. + * + * @param maxBufferSize The max buffer size (e.g., "1m", "6m") + */ + private def createSessionWithBufferConfig(maxBufferSize: String): Unit = { + if (spark != null) { + spark.stop() + } + SparkSession.clearActiveSession() + + val shimVersion = ShimLoader.getShimVersion + val shuffleManagerClass = s"com.nvidia.spark.rapids.spark" + + s"${shimVersion.toString.replace(".", "")}." + + "RapidsShuffleManager" + + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("RapidsShuffleIntegrationTest") + .set("spark.plugins", "com.nvidia.spark.SQLPlugin") + .set("spark.rapids.sql.enabled", "true") + .set("spark.rapids.sql.test.enabled", "true") + .set("spark.shuffle.manager", shuffleManagerClass) + .set("spark.shuffle.sort.io.plugin.class", + "org.apache.spark.shuffle.sort.io.RapidsLocalDiskShuffleDataIO") + .set("spark.rapids.memory.host.partialFileBufferInitialSize", "1m") + .set("spark.rapids.memory.host.partialFileBufferMaxSize", maxBufferSize) + .set("spark.rapids.sql.batchSizeBytes", "5m") + + spark = SparkSession.builder().config(conf).getOrCreate() + } + + override def beforeEach(): Unit = { + super.beforeEach() + + // Clean up any existing session before each test + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + override def afterEach(): Unit = { + try { + if (spark != null) { + spark.stop() + spark = null + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + super.afterEach() + } + } + + test("shuffle with join - 1MB max buffer triggers fallback") { + // Test verifying buffer fallback with small max buffer size + // Setup: 1MB max buffer, 1MB initial buffer, 5MB GPU batch size + // Expected: Serialized data (~4MB) exceeds 1MB limit, triggers spill to disk + createSessionWithBufferConfig("1m") + + val (hasExpansion, hasSpill, hasForcedFileOnly) = checkLogsForBufferBehavior { + val df1 = spark.range(0, 3000000, 1, 3) + .selectExpr( + "id as key", + "id * 2 as value1", + "concat('left_', cast(id as string)) as left_str" + ) + + val df2 = spark.range(0, 30000000, 1, 3) + .selectExpr( + "id as key", + "id * 3 as value2", + "concat('right_', cast(id as string)) as right_str" + ) + + val result = df1.join(df2, "key") + .selectExpr("key", "value1", "value2") + .collect() + + assert(result.length == 3000000, "Should have 3M joined rows") + } + + // Verify: buffer cannot expand beyond 1MB, must spill to disk + assert(!hasExpansion && hasSpill && hasForcedFileOnly, + s"Expected NO expansion and spill with 1MB max buffer. " + + s"expansion=$hasExpansion, spill=$hasSpill, forcedFileOnly=$hasForcedFileOnly") + } + + test("shuffle with join - 6MB max buffer avoids fallback") { + // Test verifying buffer stays in memory with sufficient max buffer size + // Setup: 6MB max buffer, 1MB initial buffer, 5MB GPU batch size + // Expected: Serialized data (~4MB) fits within 6MB limit via buffer expansion + createSessionWithBufferConfig("6m") + + val (hasExpansion, hasSpill, hasForcedFileOnly) = checkLogsForBufferBehavior { + val df1 = spark.range(0, 3000000, 1, 3) + .selectExpr( + "id as key", + "id * 2 as value1", + "concat('left_', cast(id as string)) as left_str" + ) + + val df2 = spark.range(0, 30000000, 1, 3) + .selectExpr( + "id as key", + "id * 3 as value2", + "concat('right_', cast(id as string)) as right_str" + ) + + val result = df1.join(df2, "key") + .selectExpr("key", "value1", "value2") + .collect() + + assert(result.length == 3000000, "Should have 3M joined rows") + } + + // Verify: buffer expands to accommodate data, no spill to disk + assert(hasExpansion && !hasSpill && hasForcedFileOnly, + s"Expected expansion, NO spill, forced file-only with 6MB max buffer. " + + s"expansion=$hasExpansion, spill=$hasSpill, forcedFileOnly=$hasForcedFileOnly") + } +} + diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala index fe5c3f7446f..4aadaac4941 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024-2025, NVIDIA CORPORATION. + * Copyright (c) 2024-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -510,9 +510,9 @@ class SpillFrameworkSuite } test("host originated: get host memory buffer") { - val spillPriority = -10 val hmb = HostMemoryBuffer.allocate(1L * 1024) - val spillableBuffer = SpillableHostBuffer(hmb, hmb.getLength, spillPriority) + val spillableBuffer = SpillableHostBuffer(hmb, hmb.getLength, + SpillPriorities.ACTIVE_BATCHING_PRIORITY) withResource(spillableBuffer) { _ => // the refcount of 1 is the store assertResult(1)(hmb.getRefCount) @@ -525,12 +525,11 @@ class SpillFrameworkSuite } test("host originated: get host memory buffer after spill to disk") { - val spillPriority = -10 val hmb = HostMemoryBuffer.allocate(1L * 1024) val spillableBuffer = SpillableHostBuffer( hmb, hmb.getLength, - spillPriority) + SpillPriorities.ACTIVE_BATCHING_PRIORITY) assertResult(1)(hmb.getRefCount) // we spill it SpillFramework.stores.hostStore.spill(hmb.getLength) @@ -544,9 +543,9 @@ class SpillFrameworkSuite } test("host originated: a buffer is not spillable when we leak it") { - val spillPriority = -10 val hmb = HostMemoryBuffer.allocate(1L * 1024) - withResource(SpillableHostBuffer(hmb, hmb.getLength, spillPriority)) { spillableBuffer => + withResource(SpillableHostBuffer(hmb, hmb.getLength, + SpillPriorities.ACTIVE_BATCHING_PRIORITY)) { spillableBuffer => withResource(spillableBuffer.getHostBuffer()) { _ => assertResult(0)(SpillFramework.stores.hostStore.spill(hmb.getLength)) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala new file mode 100644 index 00000000000..7f0b59ca059 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillablePartialFileHandleSuite.scala @@ -0,0 +1,576 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.spill + +import java.io.File +import java.util.Arrays + +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.RapidsConf +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite + +class SpillablePartialFileHandleSuite extends AnyFunSuite with BeforeAndAfterEach { + + // Use 1GB max buffer size for tests to avoid memory issues on test machines + private val testMaxBufferSize = 1L * 1024 * 1024 * 1024 + + override def beforeEach(): Unit = { + super.beforeEach() + val conf = new RapidsConf(Map( + "spark.rapids.memory.host.partialFileBufferMaxSize" -> testMaxBufferSize.toString + )) + SpillFramework.initialize(conf) + } + + override def afterEach(): Unit = { + SpillFramework.shutdown() + super.afterEach() + } + + test("FILE_ONLY mode: write and read") { + val tempFile = File.createTempFile("test-file-only-", ".tmp") + + withResource(SpillablePartialFileHandle.createFileOnly(tempFile)) { handle => + // Write some data + val testData = "Hello, World! This is a test.".getBytes("UTF-8") + handle.write(testData, 0, testData.length) + + assert(!handle.isMemoryBased, "FILE_ONLY should not be memory based") + assert(!handle.isSpilled, "FILE_ONLY should never report as spilled") + + // Finish write phase + handle.finishWrite() + + assert(handle.getTotalBytesWritten == testData.length) + + // Read data back + val readBuffer = new Array[Byte](testData.length) + val bytesRead = handle.read(readBuffer, 0, testData.length) + + assert(bytesRead == testData.length) + assert(readBuffer.sameElements(testData)) + + // EOF check + assert(handle.read(readBuffer, 0, 10) == -1) + } + } + + test("MEMORY_WITH_SPILL mode: write and read from memory") { + val tempFile = File.createTempFile("test-memory-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 1024, + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile)) { handle => + + assert(handle.isMemoryBased, "Should be memory based") + assert(!handle.isSpilled, "Should not be spilled initially") + + // Write small amount of data (fits in buffer) + val testData = "Small test data".getBytes("UTF-8") + handle.write(testData, 0, testData.length) + + // Finish write phase + handle.finishWrite() + + assert(handle.getTotalBytesWritten == testData.length) + assert(!handle.isSpilled, "Should still be in memory") + + // Read data back + val readBuffer = new Array[Byte](testData.length) + val bytesRead = handle.read(readBuffer, 0, testData.length) + + assert(bytesRead == testData.length) + assert(readBuffer.sameElements(testData)) + } + } + + test("MEMORY_WITH_SPILL mode: buffer expansion") { + val tempFile = File.createTempFile("test-expansion-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 64, // Small initial size to force expansion + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile)) { handle => + + // Write data larger than initial capacity + val largeData = new Array[Byte](200) + (0 until 200).foreach(i => largeData(i) = (i % 256).toByte) + + handle.write(largeData, 0, largeData.length) + + // Should have expanded the buffer + handle.finishWrite() + + assert(handle.getTotalBytesWritten == largeData.length) + + // Read and verify + val readBuffer = new Array[Byte](largeData.length) + val bytesRead = handle.read(readBuffer, 0, largeData.length) + + assert(!handle.isSpilled, "Should still be in memory") + assert(bytesRead == largeData.length) + assert(readBuffer.sameElements(largeData)) + } + } + + test("MEMORY_WITH_SPILL mode: buffer expansion then switch to file") { + val tempFile = File.createTempFile("test-expansion-switch-", ".tmp") + + // Use 600MB initial capacity, so doubling (1.2GB) would exceed 1GB limit + val initialCapacity = 600L * 1024 * 1024 + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = initialCapacity, + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile)) { handle => + + // Write data to fill initial buffer + val chunkSize = 1024 * 1024 // 1MB chunks + val chunk = new Array[Byte](chunkSize) + (0 until chunkSize).foreach(i => chunk(i) = (i % 256).toByte) + + // Fill the buffer completely + val numChunks = (testMaxBufferSize / chunkSize).toInt + (0 until numChunks).foreach { _ => + handle.write(chunk, 0, chunkSize) + } + + // Write one more byte to trigger expansion, which should fail + // and switch to file mode due to testMaxBufferSize limit + handle.write(0xFF) + + // Write more data to verify file mode is working + handle.write(chunk, 0, chunkSize) + + handle.finishWrite() + + val expectedSize = numChunks.toLong * chunkSize + 1 + chunkSize + assert(handle.getTotalBytesWritten == expectedSize) + assert(handle.isMemoryBased, "Should still be memory-based mode") + assert(handle.isSpilled, "Should have switched to file after expansion") + + // Verify we can read all data back correctly + val readBuffer = new Array[Byte](chunkSize) + + // Read first chunks + (0 until numChunks).foreach { _ => + val bytesRead = handle.read(readBuffer, 0, chunkSize) + assert(bytesRead == chunkSize) + assert(readBuffer.sameElements(chunk)) + } + + // Read the single byte + val singleByte = new Array[Byte](1) + assert(handle.read(singleByte, 0, 1) == 1) + assert(singleByte(0) == 0xFF.toByte) + + // Read last chunk + val lastBytesRead = handle.read(readBuffer, 0, chunkSize) + assert(lastBytesRead == chunkSize) + assert(readBuffer.sameElements(chunk)) + + // EOF check + assert(handle.read(readBuffer, 0, chunkSize) == -1) + } + } + + test("MEMORY_WITH_SPILL mode: manual spill") { + val tempFile = File.createTempFile("test-manual-spill-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 1024, + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile)) { handle => + + val testData = "Data to be spilled".getBytes("UTF-8") + handle.write(testData, 0, testData.length) + handle.finishWrite() + + assert(!handle.isSpilled, "Should not be spilled yet") + + // Manually spill + val spilledBytes = handle.spill() + assert(spilledBytes == testData.length) + assert(handle.isSpilled, "Should be spilled now") + + // Read from spilled file + val readBuffer = new Array[Byte](testData.length) + val bytesRead = handle.read(readBuffer, 0, testData.length) + + assert(bytesRead == testData.length) + assert(readBuffer.sameElements(testData)) + } + } + + test("MEMORY_WITH_SPILL mode: sequential write with single bytes") { + val tempFile = File.createTempFile("test-single-bytes-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 128, + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile)) { handle => + + // Write bytes one by one + val testBytes = Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + testBytes.foreach(b => handle.write(b.toInt)) + + handle.finishWrite() + + assert(handle.getTotalBytesWritten == testBytes.length) + + // Read back + val readBuffer = new Array[Byte](testBytes.length) + val bytesRead = handle.read(readBuffer, 0, testBytes.length) + + assert(bytesRead == testBytes.length) + assert(readBuffer.sameElements(testBytes)) + } + } + + test("FILE_ONLY mode: multiple partitions sequential read") { + val tempFile = File.createTempFile("test-partitions-", ".tmp") + + withResource(SpillablePartialFileHandle.createFileOnly(tempFile)) { handle => + // Simulate writing 3 partitions + val partition0 = "Partition 0 data".getBytes("UTF-8") + val partition1 = "Partition 1 data".getBytes("UTF-8") + val partition2 = "Partition 2 data".getBytes("UTF-8") + + handle.write(partition0, 0, partition0.length) + handle.write(partition1, 0, partition1.length) + handle.write(partition2, 0, partition2.length) + + handle.finishWrite() + + val totalSize = partition0.length + partition1.length + partition2.length + assert(handle.getTotalBytesWritten == totalSize) + + // Read partition 0 + val read0 = new Array[Byte](partition0.length) + assert(handle.read(read0, 0, partition0.length) == partition0.length) + assert(read0.sameElements(partition0)) + + // Read partition 1 + val read1 = new Array[Byte](partition1.length) + assert(handle.read(read1, 0, partition1.length) == partition1.length) + assert(read1.sameElements(partition1)) + + // Read partition 2 + val read2 = new Array[Byte](partition2.length) + assert(handle.read(read2, 0, partition2.length) == partition2.length) + assert(read2.sameElements(partition2)) + + // EOF + assert(handle.read(read0, 0, 10) == -1) + } + } + + test("Error handling: write after finish should fail") { + val tempFile = File.createTempFile("test-error-", ".tmp") + + withResource(SpillablePartialFileHandle.createFileOnly(tempFile)) { handle => + handle.write("test".getBytes("UTF-8"), 0, 4) + handle.finishWrite() + + assertThrows[IllegalStateException] { + handle.write("more".getBytes("UTF-8"), 0, 4) + } + } + } + + test("Error handling: read before finish should fail") { + val tempFile = File.createTempFile("test-error2-", ".tmp") + + withResource(SpillablePartialFileHandle.createFileOnly(tempFile)) { handle => + handle.write("test".getBytes("UTF-8"), 0, 4) + + val buffer = new Array[Byte](10) + assertThrows[IllegalStateException] { + handle.read(buffer, 0, 10) + } + } + } + + test("MEMORY_WITH_SPILL mode: chunked read") { + val tempFile = File.createTempFile("test-chunked-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 1024, + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile)) { handle => + + // Write 100 bytes + val testData = new Array[Byte](100) + (0 until 100).foreach(i => testData(i) = (i % 256).toByte) + handle.write(testData, 0, testData.length) + handle.finishWrite() + + // Read in chunks of 30 bytes + val allRead = new scala.collection.mutable.ArrayBuffer[Byte]() + val chunkSize = 30 + val readBuffer = new Array[Byte](chunkSize) + + var bytesRead = handle.read(readBuffer, 0, chunkSize) + while (bytesRead > 0) { + allRead ++= readBuffer.take(bytesRead) + bytesRead = handle.read(readBuffer, 0, chunkSize) + } + + assert(allRead.length == testData.length) + assert(allRead.toArray.sameElements(testData)) + } + } + + test("MEMORY_WITH_SPILL mode: mixed single byte and array writes") { + val tempFile = File.createTempFile("test-mixed-write-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 256, + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile)) { handle => + + // Mix single byte writes and array writes + handle.write(0x01) + handle.write(0x02) + + val chunk1 = Array[Byte](0x03, 0x04, 0x05, 0x06) + handle.write(chunk1, 0, chunk1.length) + + handle.write(0x07) + + val chunk2 = Array[Byte](0x08, 0x09, 0x0A) + handle.write(chunk2, 0, chunk2.length) + + handle.finishWrite() + + // Verify total bytes written + assert(handle.getTotalBytesWritten == 10) + + // Read all data back + val readBuffer = new Array[Byte](10) + val bytesRead = handle.read(readBuffer, 0, 10) + + assert(bytesRead == 10) + val expected = Array[Byte](0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A) + assert(readBuffer.sameElements(expected)) + } + } + + test("MEMORY_WITH_SPILL mode: read after spill during read phase") { + val tempFile = File.createTempFile("test-spill-during-read-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 1024, + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile)) { handle => + + // Write some test data + val testData = new Array[Byte](500) + (0 until 500).foreach(i => testData(i) = (i % 256).toByte) + handle.write(testData, 0, testData.length) + handle.finishWrite() + + assert(!handle.isSpilled, "Should not be spilled initially") + + // Read first half + val firstHalf = new Array[Byte](250) + val firstRead = handle.read(firstHalf, 0, 250) + assert(firstRead == 250) + + // Manually spill after reading first half + val spilledBytes = handle.spill() + assert(spilledBytes == testData.length) + assert(handle.isSpilled, "Should be spilled now") + + // Continue reading second half from spilled file + val secondHalf = new Array[Byte](250) + val secondRead = handle.read(secondHalf, 0, 250) + assert(secondRead == 250) + + // Verify both halves are correct + assert(firstHalf.sameElements(testData.slice(0, 250))) + assert(secondHalf.sameElements(testData.slice(250, 500))) + } + } + + test("MEMORY_WITH_SPILL mode: single byte write triggers expansion") { + val tempFile = File.createTempFile("test-single-byte-expansion-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 16, // Very small initial size + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile)) { handle => + + // Write bytes one by one to exceed initial capacity + val testData = new Array[Byte](50) + (0 until 50).foreach { i => + testData(i) = (i % 256).toByte + handle.write(testData(i).toInt) + } + + handle.finishWrite() + + assert(handle.getTotalBytesWritten == 50) + assert(!handle.isSpilled, "Should have expanded buffer, not spilled") + + // Read and verify + val readBuffer = new Array[Byte](50) + val bytesRead = handle.read(readBuffer, 0, 50) + + assert(bytesRead == 50) + assert(readBuffer.sameElements(testData)) + } + } + + test("MEMORY_WITH_SPILL mode: write phase protected from spill") { + val tempFile = File.createTempFile("test-spill-protection-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 1024, + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.5, + spillFile = tempFile)) { handle => + + // During write phase, handle should not be spillable + assert(!handle.spillable, "Should not be spillable during write phase") + + // Write some data + val testData = "Protected data".getBytes("UTF-8") + handle.write(testData, 0, testData.length) + + // Still protected + assert(!handle.spillable, "Should still not be spillable during write") + + // Attempt to spill during write phase should do nothing + val spilledBytes = handle.spill() + assert(spilledBytes == 0, "Should not spill during write phase") + assert(!handle.isSpilled, "Should not be spilled") + + // Finish write phase + handle.finishWrite() + + // After finish, should be spillable + assert(handle.spillable, "Should be spillable after write phase") + + // Now spill should work + val spilledAfterFinish = handle.spill() + assert(spilledAfterFinish == testData.length) + assert(handle.isSpilled, "Should be spilled now") + + // Verify data is still readable after spill + val readBuffer = new Array[Byte](testData.length) + handle.read(readBuffer, 0, testData.length) + assert(readBuffer.sameElements(testData)) + } + } + + test("Disk write savings metric - MEMORY_WITH_SPILL kept in memory") { + // This test verifies that when data stays in memory during write phase, + // the metric logic correctly identifies it as "saved disk write" + val tempFile = File.createTempFile("test-metric-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 1024 * 1024, // 1MB + maxBufferSize = testMaxBufferSize, + memoryThreshold = 0.9, // High threshold to avoid spill + spillFile = tempFile)) { handle => + + val testData = "Test data for disk write savings metric.".getBytes("UTF-8") + handle.write(testData, 0, testData.length) + + // Verify still in memory + assert(handle.isMemoryBased) + assert(!handle.isSpilled) + + // When finishWrite is called and data is still in memory, + // it should trigger the metric recording logic + handle.finishWrite() + + // Verify the handle state is correct for metric recording + assert(handle.getTotalBytesWritten == testData.length) + assert(handle.isMemoryBased) + assert(!handle.isSpilled) + + // Note: Without a TaskContext in unit tests, the accumulator won't be + // created, but the code should not fail. The actual metric value can + // only be verified in integration tests with real TaskContext. + } + } + + test("Disk write savings metric - FILE_ONLY should not record") { + // This test verifies that FILE_ONLY mode does not record disk write savings + val tempFile = File.createTempFile("test-metric-file-", ".tmp") + + withResource(SpillablePartialFileHandle.createFileOnly(tempFile)) { handle => + val testData = "Test data for FILE_ONLY mode.".getBytes("UTF-8") + handle.write(testData, 0, testData.length) + + // Verify it's file-only mode + assert(!handle.isMemoryBased) + + // When finishWrite is called in FILE_ONLY mode, + // it should NOT trigger metric recording + handle.finishWrite() + + // Verify state + assert(handle.getTotalBytesWritten == testData.length) + assert(!handle.isMemoryBased) + } + } + + test("Disk write savings metric - MEMORY spilled during write should not record") { + // This test verifies that if data spills during write phase, + // it should NOT record disk write savings + val tempFile = File.createTempFile("test-metric-spill-", ".tmp") + + withResource(SpillablePartialFileHandle.createMemoryWithSpill( + initialCapacity = 100, // Very small buffer + maxBufferSize = 200, // Very small max to force spill + memoryThreshold = 0.9, + spillFile = tempFile)) { handle => + + // Write data larger than max buffer size to trigger spill during write + val largeData = new Array[Byte](300) + Arrays.fill(largeData, 42.toByte) + handle.write(largeData, 0, largeData.length) + + // After writing, should have spilled during write phase + assert(handle.isMemoryBased) // Still memory-based mode + assert(handle.isSpilled) // But spilled to disk + + // When finishWrite is called after spilling during write, + // it should NOT record disk write savings + handle.finishWrite() + + // Verify state + assert(handle.getTotalBytesWritten == largeData.length) + assert(handle.isSpilled) + } + } +} + diff --git a/tests/src/test/spark321/scala/org/apache/spark/sql/rapids/RapidsShuffleThreadedWriterSuite.scala b/tests/src/test/spark321/scala/org/apache/spark/sql/rapids/RapidsShuffleThreadedWriterSuite.scala index 8d8a4a8701d..c4acb7ac0d8 100644 --- a/tests/src/test/spark321/scala/org/apache/spark/sql/rapids/RapidsShuffleThreadedWriterSuite.scala +++ b/tests/src/test/spark321/scala/org/apache/spark/sql/rapids/RapidsShuffleThreadedWriterSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,14 +32,17 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids -import java.io.{DataInputStream, File, FileInputStream, IOException, ObjectStreamException} +import java.io._ +import java.nio.ByteBuffer import java.util.UUID -import java.util.zip.CheckedInputStream import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag import scala.util.control.NonFatal +import ai.rapids.cudf.HostMemoryBuffer +import com.nvidia.spark.rapids.SlicedSerializedColumnVector import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} @@ -48,113 +51,101 @@ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar -import org.apache.spark.{HashPartitioner, SparkConf, SparkException, TaskContext} +import org.apache.spark.{HashPartitioner, SparkConf, TaskContext} import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} -import org.apache.spark.internal.{config, Logging} -import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper -import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} +import org.apache.spark.internal.Logging +import org.apache.spark.serializer._ import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.RapidsLocalDiskShuffleExecutorComponents import org.apache.spark.sql.rapids.shims.RapidsShuffleThreadedWriter -import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockManager, DiskBlockObjectWriter, ShuffleChecksumBlockId, ShuffleDataBlockId, ShuffleIndexBlockId, TempShuffleBlockId} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.storage._ import org.apache.spark.util.Utils /** - * This is mostly ported from Apache Spark's BypassMergeSortShuffleWriterSuite - * It is only targetted for Spark 3.2.0+ but it touches all of the code for Spark 3.1.2 - * as well, except for some small tweaks when committing with checksum support. + * A simple serializer for testing ColumnarBatch with SlicedSerializedColumnVector. */ -trait ShuffleChecksumTestHelper { - - /** - * Ensure that the checksum values are consistent between write and read side. - */ - def compareChecksums(numPartition: Int, - algorithm: String, - checksum: File, - data: File, - index: File): Unit = { - assert(checksum.exists(), "Checksum file doesn't exist") - assert(data.exists(), "Data file doesn't exist") - assert(index.exists(), "Index file doesn't exist") - - var checksumIn: DataInputStream = null - val expectChecksums = Array.ofDim[Long](numPartition) - try { - checksumIn = new DataInputStream(new FileInputStream(checksum)) - (0 until numPartition).foreach(i => expectChecksums(i) = checksumIn.readLong()) - } finally { - if (checksumIn != null) { - checksumIn.close() - } - } +class TestColumnarBatchSerializer extends Serializer with Serializable { + override def newInstance(): SerializerInstance = new TestColumnarBatchSerializerInstance() + override def supportsRelocationOfSerializedObjects: Boolean = true +} - var dataIn: FileInputStream = null - var indexIn: DataInputStream = null - var checkedIn: CheckedInputStream = null - try { - dataIn = new FileInputStream(data) - indexIn = new DataInputStream(new FileInputStream(index)) - var prevOffset = indexIn.readLong - (0 until numPartition).foreach { i => - val curOffset = indexIn.readLong - val limit = (curOffset - prevOffset).toInt - val bytes = new Array[Byte](limit) - val checksumCal = ShuffleChecksumHelper.getChecksumByAlgorithm(algorithm) - checkedIn = new CheckedInputStream( - new LimitedInputStream(dataIn, curOffset - prevOffset), checksumCal) - checkedIn.read(bytes, 0, limit) - prevOffset = curOffset - // checksum must be consistent at both write and read sides - assert(checkedIn.getChecksum.getValue == expectChecksums(i)) - } - } finally { - if (dataIn != null) { - dataIn.close() - } - if (indexIn != null) { - indexIn.close() - } - if (checkedIn != null) { - checkedIn.close() - } - } +class TestColumnarBatchSerializerInstance extends SerializerInstance { + override def serialize[T: ClassTag](t: T): ByteBuffer = { + val bos = new ByteArrayOutputStream() + val stream = serializeStream(bos) + stream.writeObject(t) + stream.close() + ByteBuffer.wrap(bos.toByteArray) } + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException("Not implemented for test") + + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException("Not implemented for test") + + override def serializeStream(s: OutputStream): SerializationStream = + new TestColumnarBatchSerializationStream(s) + + override def deserializeStream(s: InputStream): DeserializationStream = + throw new UnsupportedOperationException("Not implemented for test") } -class BadSerializable(i: Int) extends Serializable { - @throws(classOf[ObjectStreamException]) - def writeReplace(): Object = { - if (i >= 500) { - throw new IOException(s"failed to write $i") +class TestColumnarBatchSerializationStream(out: OutputStream) extends SerializationStream { + private val dataOut = new DataOutputStream(out) + + override def writeObject[T: ClassTag](t: T): SerializationStream = { + t match { + case batch: ColumnarBatch => + dataOut.writeInt(batch.numCols()) + for (i <- 0 until batch.numCols()) { + batch.column(i) match { + case col: SlicedSerializedColumnVector => + val hmb = col.getWrap + val size = hmb.getLength.toInt + dataOut.writeInt(size) + val bytes = new Array[Byte](size) + hmb.getBytes(bytes, 0, 0, size) + dataOut.write(bytes) + case _ => + dataOut.writeInt(0) + } + } + case key: Int => + dataOut.writeInt(key) + case _ => + dataOut.writeInt(-1) } this } + + override def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key) + override def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value) + override def flush(): Unit = dataOut.flush() + override def close(): Unit = dataOut.close() } -// Added so that we can mock a method that was created in Spark 3.3.0 -// and the override can be used in all versions of Spark + +// Shim for Spark 3.3.0+ createTempFile method trait ShimIndexShuffleBlockResolver330 { def createTempFile(file: File): File } -class TestIndexShuffleBlockResolver( - conf: SparkConf, - bm: BlockManager) - extends IndexShuffleBlockResolver(conf, bm) - with ShimIndexShuffleBlockResolver330 { - // implemented in Spark 3.3.0 - override def createTempFile(file: File): File = { null } + +class TestIndexShuffleBlockResolver(conf: SparkConf, bm: BlockManager) + extends IndexShuffleBlockResolver(conf, bm) with ShimIndexShuffleBlockResolver330 { + override def createTempFile(file: File): File = null } + class RapidsShuffleThreadedWriterSuite extends AnyFunSuite with BeforeAndAfterEach with BeforeAndAfterAll with MockitoSugar - with ShuffleChecksumTestHelper with Logging { + @scala.annotation.nowarn("msg=consider using immutable val") @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ @@ -169,11 +160,7 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite @scala.annotation.nowarn("msg=consider using immutable val") @Mock(answer = RETURNS_SMART_NULLS) - private var dependency: GpuShuffleDependency[Int, Int, Int] = _ - - @scala.annotation.nowarn("msg=consider using immutable val") - @Mock(answer = RETURNS_SMART_NULLS) - private var dependencyBad: GpuShuffleDependency[Int, BadSerializable, BadSerializable] = _ + private var dependency: GpuShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _ private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ @@ -183,10 +170,86 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite .set("spark.app.id", "sampleApp") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] - private var shuffleHandle: ShuffleHandleWithMetrics[Int, Int, Int] = _ + private var shuffleHandle: ShuffleHandleWithMetrics[Int, ColumnarBatch, ColumnarBatch] = _ + + // Track the sliced buffers (wrap) for cleanup, since incRefCountAndGetSize increases refCount + private val slicedBuffersToClean: mutable.Buffer[HostMemoryBuffer] = + new ArrayBuffer[HostMemoryBuffer]() private val numWriterThreads = 2 + private def createTestBatch(value: Int): ColumnarBatch = { + val bufferSize = 64 + (value % 64) + val hmb = HostMemoryBuffer.allocate(bufferSize) + for (i <- 0 until bufferSize) { + hmb.setByte(i, (value + i).toByte) + } + val cv = new SlicedSerializedColumnVector(hmb, 0, bufferSize) + // Save the sliced buffer (wrap) for cleanup, NOT the original hmb + // incRefCountAndGetSize will increase wrap's refCount, we need to close it once more + slicedBuffersToClean += cv.getWrap + // Close original hmb since SlicedSerializedColumnVector.slice() increased its refCount + hmb.close() + new ColumnarBatch(Array(cv), 1) + } + + private def createTestRecords(keys: Iterator[Int]): Iterator[(Int, ColumnarBatch)] = + keys.map(key => (key, createTestBatch(key))) + + private def createWriter(): RapidsShuffleThreadedWriter[Int, ColumnarBatch] = { + new RapidsShuffleThreadedWriter[Int, ColumnarBatch]( + blockManager, shuffleHandle, 0L, conf, + new ThreadSafeShuffleWriteMetricsReporter(taskContext.taskMetrics().shuffleWriteMetrics), + 1024 * 1024, shuffleExecutorComponents, numWriterThreads) + } + + /** + * Verify write results including partition data presence. + * @param partitionsWithData Set of partition IDs that should have data + * @param minWritesPerPartition Optional map specifying minimum write count per partition. + * Used to verify multiple batches wrote to same partition. + */ + private def verifyWrite( + writer: RapidsShuffleThreadedWriter[Int, ColumnarBatch], + expectedRecords: Int, + partitionsWithData: Set[Int], + minWritesPerPartition: Map[Int, Int] = Map.empty): Unit = { + val partitionLengths = writer.getPartitionLengths + val numPartitions = partitionLengths.length + + // Basic checks + assert(partitionLengths.sum === outputFile.length(), + s"Partition lengths sum ${partitionLengths.sum} != file length ${outputFile.length()}") + assert(writer.getBytesInFlight == 0, "bytesInFlight should be 0 after completion") + assert(taskContext.taskMetrics().shuffleWriteMetrics.recordsWritten === expectedRecords, + s"Expected $expectedRecords records, got " + + s"${taskContext.taskMetrics().shuffleWriteMetrics.recordsWritten}") + + // Verify each partition that should have data actually has data + for (partitionId <- partitionsWithData) { + assert(partitionLengths(partitionId) > 0, + s"Partition $partitionId should have data but length is ${partitionLengths(partitionId)}") + } + + // Verify partitions NOT in the set are empty + for (partitionId <- 0 until numPartitions if !partitionsWithData.contains(partitionId)) { + assert(partitionLengths(partitionId) == 0, + s"Partition $partitionId should be empty but length is ${partitionLengths(partitionId)}") + } + + // Verify multiple writes to same partition by checking data length + // Each write to partition P adds at least minBytesPerWrite bytes + // (key int + column count int + buffer size int + buffer data) + val minBytesPerWrite = 4 + 4 + 4 + 64 // at least 76 bytes per record + for ((partitionId, minWrites) <- minWritesPerPartition) { + val expectedMinLength = minWrites * minBytesPerWrite + assert(partitionLengths(partitionId) >= expectedMinLength, + s"Partition $partitionId: expected at least $minWrites writes " + + s"(>= $expectedMinLength bytes), but got ${partitionLengths(partitionId)} bytes. " + + s"This suggests fewer records were written than expected.") + } + } + override def beforeAll(): Unit = { RapidsShuffleInternalManagerBase.startThreadPoolIfNeeded(numWriterThreads, 0) } @@ -197,22 +260,23 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite override def beforeEach(): Unit = { super.beforeEach() + RapidsShuffleInternalManagerBase.startThreadPoolIfNeeded(numWriterThreads, 0) TaskContext.setTaskContext(taskContext) MockitoAnnotations.openMocks(this).close() tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) taskMetrics = spy(new TaskMetrics) val shuffleWriteMetrics = new ShuffleWriteMetrics - shuffleHandle = new ShuffleHandleWithMetrics[Int, Int, Int]( + shuffleHandle = new ShuffleHandleWithMetrics[Int, ColumnarBatch, ColumnarBatch]( 0, Map.empty, dependency) when(dependency.partitioner).thenReturn(new HashPartitioner(7)) - when(dependency.serializer).thenReturn(new JavaSerializer(conf)) - when(dependencyBad.partitioner).thenReturn(new HashPartitioner(7)) - when(dependencyBad.serializer).thenReturn(new JavaSerializer(conf)) + when(dependency.serializer).thenReturn(new TestColumnarBatchSerializer()) when(taskMetrics.shuffleWriteMetrics).thenReturn(shuffleWriteMetrics) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(blockManager.serializerManager) + .thenReturn(new SerializerManager(new TestColumnarBatchSerializer(), conf)) when(blockResolver.writeMetadataFileAndCommit( anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File]))) @@ -226,28 +290,14 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite } when(blockManager.getDiskWriter( - any[BlockId], - any[File], - any[SerializerInstance], - anyInt(), - any[ShuffleWriteMetrics])) + any[BlockId], any[File], any[SerializerInstance], anyInt(), any[ShuffleWriteMetrics])) .thenAnswer { invocation => val args = invocation.getArguments - val manager = new SerializerManager(new JavaSerializer(conf), conf) + val manager = new SerializerManager(new TestColumnarBatchSerializer(), conf) new DiskBlockObjectWriter( - args(1).asInstanceOf[File], - manager, - args(2).asInstanceOf[SerializerInstance], - args(3).asInstanceOf[Int], - syncWrites = false, - args(4).asInstanceOf[ShuffleWriteMetrics], - blockId = args(0).asInstanceOf[BlockId]) - } - - when(blockResolver.createTempFile(any(classOf[File]))) - .thenAnswer { invocationOnMock => - val file = invocationOnMock.getArguments()(0).asInstanceOf[File] - Utils.tempFileWith(file) + args(1).asInstanceOf[File], manager, args(2).asInstanceOf[SerializerInstance], + args(3).asInstanceOf[Int], syncWrites = false, + args(4).asInstanceOf[ShuffleWriteMetrics], blockId = args(0).asInstanceOf[BlockId]) } when(diskBlockManager.createTempShuffleBlock()) @@ -259,234 +309,168 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite (blockId, file) } - when(diskBlockManager.getFile(any[BlockId])).thenAnswer { invocation => - blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId]) - } - - shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( - conf, blockManager, blockResolver) + shuffleExecutorComponents = + new RapidsLocalDiskShuffleExecutorComponents(conf, blockManager, blockResolver) } override def afterEach(): Unit = { TaskContext.unset() blockIdToFileMap.clear() temporaryFilesCreated.clear() - try { - Utils.deleteRecursively(tempDir) - } catch { - case NonFatal(e) => - // Catch non-fatal errors here as we are cleaning up directories using a Spark utility - // and we shouldn't fail a test for these exceptions. See: - // https://github.com/NVIDIA/spark-rapids/issues/6515 - logWarning(s"Error while cleaning up $tempDir", e) - } finally { - super.afterEach() + // Close sliced buffers to release the refCount added by incRefCountAndGetSize + slicedBuffersToClean.foreach { buf => + try { buf.close() } catch { case NonFatal(_) => } } + slicedBuffersToClean.clear() + RapidsShuffleInternalManagerBase.stopThreadPool() + try { Utils.deleteRecursively(tempDir) } catch { case NonFatal(_) => } } + // ==================== Basic Tests ==================== + test("write empty iterator") { - val writer = new RapidsShuffleThreadedWriter[Int, Int]( - blockManager, - shuffleHandle, - 0L, // MapId - conf, - taskContext.taskMetrics().shuffleWriteMetrics, - 1024 * 1024, - shuffleExecutorComponents, - numWriterThreads) + val writer = createWriter() writer.write(Iterator.empty) - writer.stop( /* success = */ true) + writer.stop(true) assert(writer.getPartitionLengths.sum === 0) - assert(writer.getBytesInFlight == 0) - assert(outputFile.exists()) assert(outputFile.length() === 0) - assert(temporaryFilesCreated.isEmpty) - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === 0) - assert(shuffleWriteMetrics.recordsWritten === 0) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) } - Seq(true, false).foreach { transferTo => - test(s"write with some empty partitions - transferTo $transferTo") { - val transferConf = conf.clone.set("spark.file.transferTo", transferTo.toString) - def records: Iterator[(Int, Int)] = - Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - val writer = new RapidsShuffleThreadedWriter[Int, Int]( - blockManager, - shuffleHandle, - 0L, // MapId - transferConf, - new ThreadSafeShuffleWriteMetricsReporter(taskContext.taskMetrics().shuffleWriteMetrics), - 1024 * 1024, - shuffleExecutorComponents, - numWriterThreads) - writer.write(records) - writer.stop( /* success = */ true) - assert(temporaryFilesCreated.nonEmpty) - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files - assert(writer.getBytesInFlight == 0) - assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - } + test("single batch: sequential partitions") { + // Single batch with strictly increasing partition IDs + // Input: 0,1,2,3,4,5,6 -> all in one batch + val writer = createWriter() + writer.write(createTestRecords(Iterator(0, 1, 2, 3, 4, 5, 6))) + writer.stop(true) + verifyWrite(writer, expectedRecords = 7, partitionsWithData = Set(0, 1, 2, 3, 4, 5, 6)) } - test("only generate temp shuffle file for non-empty partition") { - // Using exception to test whether only non-empty partition creates temp shuffle file, - // because temp shuffle file will only be cleaned after calling stop(false) in the failure - // case, so we could use it to validate the temp shuffle files. - def records: Iterator[(Int, Int)] = - Iterator((1, 1), (5, 5)) ++ - (0 until 100000).iterator.map { i => - if (i == 99990) { - throw new SparkException("intentional failure") - } else { - (2, 2) - } - } + // ==================== Multi-batch: Basic Scenarios ==================== + + test("multi-batch: batch2 fills batch1 gaps") { + // Batch1: 1,3,5 (odd partitions) + // Batch2: 0,2,4 (even partitions, triggered by 0 < 5) + // Result: partition 6 empty + val writer = createWriter() + writer.write(createTestRecords(Iterator(1, 3, 5, 0, 2, 4))) + writer.stop(true) + // Both batch1 (1,3,5) and batch2 (0,2,4) data must be present + verifyWrite(writer, expectedRecords = 6, partitionsWithData = Set(0, 1, 2, 3, 4, 5)) + } - val writer = new RapidsShuffleThreadedWriter[Int, Int]( - blockManager, - shuffleHandle, - 0L, // MapId - conf, - taskContext.taskMetrics().shuffleWriteMetrics, - 1024 * 1024, - shuffleExecutorComponents, - numWriterThreads) - - intercept[SparkException] { - writer.write(records) - } + test("multi-batch: extreme jump max to min") { + // Batch1: 6 (max partition only) + // Batch2: 0 (min partition only, triggered by 0 < 6) + // Result: partitions 1-5 empty + val writer = createWriter() + writer.write(createTestRecords(Iterator(6, 0))) + writer.stop(true) + // batch1 has partition 6, batch2 has partition 0 + verifyWrite(writer, expectedRecords = 2, partitionsWithData = Set(0, 6)) + } - assert(temporaryFilesCreated.nonEmpty) - // Only 3 temp shuffle files will be created - assert(temporaryFilesCreated.count(_.exists()) === 3) + // ==================== Multi-batch: Overlap Scenarios ==================== + + test("multi-batch: partitions overlap between batches") { + // Batch1: 1,3,5 + // Batch2: 3,4,5 (triggered by 3 < 5, partitions 3,5 written again) + // Partitions 3,5 have data from BOTH batches + val writer = createWriter() + writer.write(createTestRecords(Iterator(1, 3, 5, 3, 4, 5))) + writer.stop(true) + // batch1 contributes 1,3,5; batch2 contributes 3,4,5 -> union is 1,3,4,5 + // Partitions 3 and 5 should have 2 writes each + verifyWrite(writer, expectedRecords = 6, partitionsWithData = Set(1, 3, 4, 5), + minWritesPerPartition = Map(3 -> 2, 5 -> 2)) + } - writer.stop( /* success = */ false) - assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted - assert(writer.getBytesInFlight == 0) + test("multi-batch: batch2 fully within batch1 range") { + // Batch1: 0,1,2,3,4,5,6 (all partitions) + // Batch2: 2,3,4 (triggered by 2 < 6, subset of batch1) + val writer = createWriter() + writer.write(createTestRecords(Iterator(0, 1, 2, 3, 4, 5, 6, 2, 3, 4))) + writer.stop(true) + // All 7 partitions have data; partitions 2,3,4 have 2 writes each + verifyWrite(writer, expectedRecords = 10, partitionsWithData = Set(0, 1, 2, 3, 4, 5, 6), + minWritesPerPartition = Map(2 -> 2, 3 -> 2, 4 -> 2)) } - test("cleanup of intermediate files after errors") { - val writer = new RapidsShuffleThreadedWriter[Int, Int]( - blockManager, - shuffleHandle, - 0L, // MapId - conf, - taskContext.taskMetrics().shuffleWriteMetrics, - 1024 * 1024, - shuffleExecutorComponents, - numWriterThreads) - intercept[SparkException] { - writer.write((0 until 100000).iterator.map(i => { - if (i == 99990) { - throw new SparkException("Intentional failure") - } - (i, i) - })) - } - assert(temporaryFilesCreated.nonEmpty) - writer.stop( /* success = */ false) - assert(temporaryFilesCreated.count(_.exists()) === 0) - assert(writer.getBytesInFlight == 0) + // ==================== Multi-batch: Repeated Partitions ==================== + + test("single batch: same partition repeated") { + // Consecutive identical partition IDs can occur in two scenarios: + // 1. Reslicing: a large partition is split into multiple smaller batches + // 2. Data skew: multiple GPU batches each containing only the same partition's data + // In both cases, they should be merged into a single shuffle batch (more efficient, + // fewer partial files). This does NOT affect correctness since shuffle write only + // cares about final data completeness per partition. + // Input: 0,0,0,0,0 -> all in one batch + val writer = createWriter() + writer.write(createTestRecords(Iterator(0, 0, 0, 0, 0))) + writer.stop(true) + // Only partition 0 has data, all 5 records in a single batch + // Verify partition 0 was written 5 times + verifyWrite(writer, expectedRecords = 5, partitionsWithData = Set(0), + minWritesPerPartition = Map(0 -> 5)) } - test("write checksum file") { - // this is a spy so we can intercept calls to `createTempShuffleBlock` - // in spark 3.3.0+ - val blockResolver = spy(new TestIndexShuffleBlockResolver(conf, blockManager)) - val shuffleId = shuffleHandle.shuffleId - val mapId = 0 - val checksumBlockId = ShuffleChecksumBlockId(shuffleId, mapId, 0) - val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0) - val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0) - val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) - val checksumFileName = ShuffleChecksumHelper.getChecksumFileName( - checksumBlockId.name, checksumAlgorithm) - val checksumFile = new File(tempDir, checksumFileName) - val dataFile = new File(tempDir, dataBlockId.name) - val indexFile = new File(tempDir, indexBlockId.name) - reset(diskBlockManager) - when(diskBlockManager.getFile(checksumFileName)).thenAnswer(_ => checksumFile) - when(diskBlockManager.getFile(dataBlockId)).thenAnswer(_ => dataFile) - when(diskBlockManager.getFile(indexBlockId)).thenAnswer(_ => indexFile) - when(diskBlockManager.createTempShuffleBlock()) - .thenAnswer { _ => - val blockId = new TempShuffleBlockId(UUID.randomUUID) - val file = new File(tempDir, blockId.name) - temporaryFilesCreated += file - (blockId, file) - } + test("multi-batch: strictly decreasing creates one batch per record") { + // Input: 5,4,3,2,1,0 + // Each partition ID < previous max, so 6 batches total + // Batch1:5, Batch2:4, Batch3:3, Batch4:2, Batch5:1, Batch6:0 + val writer = createWriter() + writer.write(createTestRecords(Iterator(5, 4, 3, 2, 1, 0))) + writer.stop(true) + // All batches contribute: partitions 0,1,2,3,4,5 have data + verifyWrite(writer, expectedRecords = 6, partitionsWithData = Set(0, 1, 2, 3, 4, 5)) + } - when(blockResolver.createTempFile(any(classOf[File]))) - .thenAnswer { invocationOnMock => - val file = invocationOnMock.getArguments()(0).asInstanceOf[File] - Utils.tempFileWith(file) - } + test("multi-batch: oscillating between two partitions") { + // Input: 2,5,2,5,2,5 + // Batch1: 2,5; Batch2: 2,5; Batch3: 2,5 + val writer = createWriter() + writer.write(createTestRecords(Iterator(2, 5, 2, 5, 2, 5))) + writer.stop(true) + // Only partitions 2 and 5 have data (from all 3 batches) + // Each partition should have 3 writes + verifyWrite(writer, expectedRecords = 6, partitionsWithData = Set(2, 5), + minWritesPerPartition = Map(2 -> 3, 5 -> 3)) + } - val numPartition = shuffleHandle.dependency.partitioner.numPartitions - val writer = new RapidsShuffleThreadedWriter[Int, Int]( - blockManager, - shuffleHandle, - mapId, - conf, - new ThreadSafeShuffleWriteMetricsReporter(taskContext.taskMetrics().shuffleWriteMetrics), - 1024 * 1024, - new LocalDiskShuffleExecutorComponents(conf, blockManager, blockResolver), - numWriterThreads) - - writer.write(Iterator((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6))) - writer.stop( /* success = */ true) - assert(writer.getBytesInFlight == 0) - assert(checksumFile.exists()) - assert(checksumFile.length() === 8 * numPartition) - compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile, indexFile) + // ==================== Multi-batch: Size Variations ==================== + + test("multi-batch: batch1 sparse, batch2 full") { + // Batch1: 0,6 (only first and last) + // Batch2: 0,1,2,3,4,5,6 (all partitions, triggered by 0 < 6) + val writer = createWriter() + writer.write(createTestRecords(Iterator(0, 6, 0, 1, 2, 3, 4, 5, 6))) + writer.stop(true) + // batch1 contributes 0,6; batch2 contributes all -> all partitions have data + // Partitions 0 and 6 should have 2 writes each + verifyWrite(writer, expectedRecords = 9, partitionsWithData = Set(0, 1, 2, 3, 4, 5, 6), + minWritesPerPartition = Map(0 -> 2, 6 -> 2)) } - Seq(true, false).foreach { stopWithSuccess => - test(s"create an exception in one of the writers and stop with success = $stopWithSuccess") { - def records: Iterator[(Int, BadSerializable)] = - Iterator( - (1, new BadSerializable(1)), - (5, new BadSerializable(5))) ++ - (10 until 100000).iterator.map(x => (2, new BadSerializable(x))) - - val shuffleHandle = new ShuffleHandleWithMetrics[Int, BadSerializable, BadSerializable]( - 0, - Map.empty, - dependencyBad - ) - val writer = new RapidsShuffleThreadedWriter[Int, BadSerializable]( - blockManager, - shuffleHandle, - 0L, // MapId - conf, - new ThreadSafeShuffleWriteMetricsReporter(taskContext.taskMetrics().shuffleWriteMetrics), - 1024 * 1024, - shuffleExecutorComponents, - numWriterThreads) - assertThrows[IOException] { - writer.write(records) - } - if (stopWithSuccess) { - assertThrows[IllegalStateException] { - writer.stop(true) - } - } else { - writer.stop(false) - } - assert(temporaryFilesCreated.nonEmpty) - assert(writer.getPartitionLengths == null) - assert(writer.getBytesInFlight == 0) - assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted - } + test("multi-batch: batch2 extends beyond batch1 range") { + // Batch1: 2,3 (middle partitions) + // Batch2: 0,1,4,5,6 (triggered by 0 < 3, covers both sides) + val writer = createWriter() + writer.write(createTestRecords(Iterator(2, 3, 0, 1, 4, 5, 6))) + writer.stop(true) + // batch1: 2,3; batch2: 0,1,4,5,6 -> all partitions + verifyWrite(writer, expectedRecords = 7, partitionsWithData = Set(0, 1, 2, 3, 4, 5, 6)) } -} + // ==================== Multi-batch: Three+ Batches ==================== + + test("multi-batch: three batches interleaved") { + // Batch1: 2,4,6 + // Batch2: 1,3,5 (triggered by 1 < 6) + // Batch3: 0 (triggered by 0 < 5) + val writer = createWriter() + writer.write(createTestRecords(Iterator(2, 4, 6, 1, 3, 5, 0))) + writer.stop(true) + // batch1: 2,4,6; batch2: 1,3,5; batch3: 0 -> all partitions + verifyWrite(writer, expectedRecords = 7, partitionsWithData = Set(0, 1, 2, 3, 4, 5, 6)) + } +}