Skip to content

Commit 1594c8d

Browse files
committed
Move files using NIO if the shuffle dir is mounted as a file system.
Signed-off-by: Pascal Spörri <[email protected]>
1 parent 0b56e2b commit 1594c8d

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

src/main/scala/org/apache/spark/shuffle/S3SingleSpillShuffleMapOutputWriter.scala

+30-5
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55

66
package org.apache.spark.shuffle
77

8+
import org.apache.spark.TaskContext
9+
import org.apache.spark.internal.Logging
810
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter
911
import org.apache.spark.shuffle.helper.{S3ShuffleDispatcher, S3ShuffleHelper}
1012
import org.apache.spark.storage.ShuffleDataBlockId
1113
import org.apache.spark.util.Utils
1214

1315
import java.io.{File, FileInputStream}
16+
import java.nio.file.{Files, Path}
1417

15-
class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends SingleSpillShuffleMapOutputWriter {
18+
class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends SingleSpillShuffleMapOutputWriter with Logging {
1619

1720
private lazy val dispatcher = S3ShuffleDispatcher.get
1821

@@ -21,12 +24,34 @@ class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends S
2124
partitionLengths: Array[Long],
2225
checksums: Array[Long]
2326
): Unit = {
24-
val in = new FileInputStream(mapSpillFile)
2527
val block = ShuffleDataBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
26-
val out = new S3MeasureOutputStream(dispatcher.createBlock(block), block.name)
2728

28-
// Note: HDFS does not expose a nio-buffered write interface.
29-
Utils.copyStream(in, out, closeStreams = true)
29+
if (dispatcher.rootIsLocal) {
30+
// Use NIO to move the file if the folder is local.
31+
val now = System.nanoTime()
32+
val path = dispatcher.getPath(block)
33+
val fileDestination = path.toUri.getRawPath
34+
val dir = path.getParent
35+
if (!dispatcher.fs.exists(dir)) {
36+
dispatcher.fs.mkdirs(dir)
37+
}
38+
Files.move(mapSpillFile.toPath, Path.of(fileDestination))
39+
val timings = System.nanoTime() - now
40+
41+
val bytes = partitionLengths.sum
42+
val tc = TaskContext.get()
43+
val sId = tc.stageId()
44+
val sAt = tc.stageAttemptNumber()
45+
val t = timings / 1000000
46+
val bw = bytes.toDouble / (t.toDouble / 1000) / (1024 * 1024)
47+
logInfo(s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " +
48+
s"Writing ${block.name} ${bytes} took ${t} ms (${bw} MiB/s)")
49+
} else {
50+
// Copy using a stream.
51+
val in = new FileInputStream(mapSpillFile)
52+
val out = new S3MeasureOutputStream(dispatcher.createBlock(block), block.name)
53+
Utils.copyStream(in, out, closeStreams = true)
54+
}
3055

3156
if (dispatcher.checksumEnabled) {
3257
S3ShuffleHelper.writeChecksum(shuffleId, mapId, checksums)

src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class S3ShuffleDispatcher extends Logging {
4646
}
4747
private val rootDir_ = if (useSparkShuffleFetch) fallbackStoragePath else conf.get("spark.shuffle.s3.rootDir", defaultValue = "sparkS3shuffle/")
4848
val rootDir: String = if (rootDir_.endsWith("/")) rootDir_ else rootDir_ + "/"
49+
val rootIsLocal: Boolean = URI.create(rootDir).getScheme == "file"
4950

5051
// Optional
5152
val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = 8 * 1024 * 1024)

0 commit comments

Comments
 (0)