5
5
6
6
package org .apache .spark .shuffle
7
7
8
+ import org .apache .spark .TaskContext
9
+ import org .apache .spark .internal .Logging
8
10
import org .apache .spark .shuffle .api .SingleSpillShuffleMapOutputWriter
9
11
import org .apache .spark .shuffle .helper .{S3ShuffleDispatcher , S3ShuffleHelper }
10
12
import org .apache .spark .storage .ShuffleDataBlockId
11
13
import org .apache .spark .util .Utils
12
14
13
15
import java .io .{File , FileInputStream }
16
+ import java .nio .file .{Files , Path }
14
17
15
- class S3SingleSpillShuffleMapOutputWriter (shuffleId : Int , mapId : Long ) extends SingleSpillShuffleMapOutputWriter {
18
+ class S3SingleSpillShuffleMapOutputWriter (shuffleId : Int , mapId : Long ) extends SingleSpillShuffleMapOutputWriter with Logging {
16
19
17
20
// Dispatcher resolved lazily on first use via S3ShuffleDispatcher.get; used below
// for path resolution, filesystem access and block creation during the spill transfer.
private lazy val dispatcher = S3ShuffleDispatcher .get
18
21
@@ -21,12 +24,34 @@ class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends S
21
24
partitionLengths : Array [Long ],
22
25
checksums : Array [Long ]
23
26
): Unit = {
24
- val in = new FileInputStream (mapSpillFile)
25
27
val block = ShuffleDataBlockId (shuffleId, mapId, IndexShuffleBlockResolver .NOOP_REDUCE_ID )
26
- val out = new S3MeasureOutputStream (dispatcher.createBlock(block), block.name)
27
28
28
- // Note: HDFS does not exposed a nio-buffered write interface.
29
- Utils .copyStream(in, out, closeStreams = true )
29
+ if (dispatcher.rootIsLocal) {
30
+ // Use NIO to move the file if the folder is local.
31
+ val now = System .nanoTime()
32
+ val path = dispatcher.getPath(block)
33
+ val fileDestination = path.toUri.getRawPath
34
+ val dir = path.getParent
35
+ if (! dispatcher.fs.exists(dir)) {
36
+ dispatcher.fs.mkdirs(dir)
37
+ }
38
+ Files .move(mapSpillFile.toPath, Path .of(fileDestination))
39
+ val timings = System .nanoTime() - now
40
+
41
+ val bytes = partitionLengths.sum
42
+ val tc = TaskContext .get()
43
+ val sId = tc.stageId()
44
+ val sAt = tc.stageAttemptNumber()
45
+ val t = timings / 1000000
46
+ val bw = bytes.toDouble / (t.toDouble / 1000 ) / (1024 * 1024 )
47
+ logInfo(s " Statistics: Stage ${sId}. ${sAt} TID ${tc.taskAttemptId()} -- " +
48
+ s " Writing ${block.name} ${bytes} took ${t} ms ( ${bw} MiB/s) " )
49
+ } else {
50
+ // Copy using a stream.
51
+ val in = new FileInputStream (mapSpillFile)
52
+ val out = new S3MeasureOutputStream (dispatcher.createBlock(block), block.name)
53
+ Utils .copyStream(in, out, closeStreams = true )
54
+ }
30
55
31
56
if (dispatcher.checksumEnabled) {
32
57
S3ShuffleHelper .writeChecksum(shuffleId, mapId, checksums)
0 commit comments