Integrate support for scalafmt. #87

Merged: 4 commits, May 8, 2024
8 changes: 7 additions & 1 deletion .github/workflows/ci.yml
@@ -10,7 +10,7 @@ on:
- main
tags:
- v*

jobs:
Build:
strategy:
@@ -66,6 +66,12 @@ jobs:
distribution: temurin
java-version: 11
cache: sbt
- name: Check formatting
shell: bash
run: |
echo "If either of these checks fail run: 'sbt scalafmtAll && sbt scalafmtSbt'"
sbt scalafmtSbtCheck
sbt scalafmtCheckAll
- name: Test Default Shuffle Fetch
shell: bash
if: startsWith(matrix.scala, '2.12.')
3 changes: 3 additions & 0 deletions .scalafmt.conf
@@ -0,0 +1,3 @@
version = 3.7.13
runner.dialect = scala212
maxColumn = 120
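
For illustration only (not part of the diff): with `maxColumn = 120` and the `scala212` dialect, `sbt scalafmtAll` rewraps definitions that exceed 120 columns one parameter per line, which is exactly the kind of change visible in the `build.sbt` and `S3ShuffleMapOutputWriter.scala` hunks below. The object, method, and parameters in this sketch are made up.

```scala
object FormattingSketch {
  // Before formatting (kept as a comment so the snippet stays compilable), a one-line
  // definition longer than 120 columns:
  //   def transfer(mapSpillFile: java.io.File, partitionLengths: Array[Long], checksums: Array[Long], label: String, verbose: Boolean): Unit = ()
  // After `sbt scalafmtAll` with the configuration above, roughly:
  def transfer(
      mapSpillFile: java.io.File,
      partitionLengths: Array[Long],
      checksums: Array[Long],
      label: String,
      verbose: Boolean
  ): Unit = ()
}
```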
28 changes: 19 additions & 9 deletions README.md
@@ -16,7 +16,17 @@ Examples are available [here](./examples).
```bash
sbt package # Creates a minimal jar.
sbt assembly # Creates the full assembly with all dependencies, notably hadoop cloud.
```

## Formatting Code

Formatting is done with `scalafmt`. It can be triggered with the following commands.

```bash
sbt scalafmtAll # Format the source code
sbt scalafmtSbt # Format the sbt build files.
```


## Required configuration

@@ -86,14 +96,14 @@ to Java > 11:
--add-opens=java.base/java.io=ALL-UNNAMED
--add-opens=java.base/java.net=ALL-UNNAMED
--add-opens=java.base/java.nio=ALL-UNNAMED
--add-opens=java.base/java.util=ALL-UNNAMED
--add-opens=java.base/java.util.concurrent=ALL-UNNAMED
--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED
--add-opens=java.base/sun.nio.ch=ALL-UNNAMED
--add-opens=java.base/sun.nio.cs=ALL-UNNAMED
--add-opens=java.base/sun.security.action=ALL-UNNAMED
--add-opens=java.base/sun.util.calendar=ALL-UNNAMED
--add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED
```

## Usage
@@ -119,7 +129,7 @@ Add the following lines to your Spark configuration:
--conf spark.hadoop.fs.s3a.endpoint=S3A_ENDPOINT
--conf spark.hadoop.fs.s3a.path.style.access=true
--conf spark.hadoop.fs.s3a.fast.upload=true

--conf spark.shuffle.manager="org.apache.spark.shuffle.sort.S3ShuffleManager"
--conf spark.shuffle.sort.io.plugin.class="org.apache.spark.shuffle.S3ShuffleDataIO"
--conf spark.hadoop.fs.s3a.impl="org.apache.hadoop.fs.s3a.S3AFileSystem"
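
As an aside (not part of the diff): the same settings the README passes via `--conf` can be set programmatically. A minimal sketch, with a placeholder endpoint value, might look like this.

```scala
import org.apache.spark.SparkConf

// Minimal sketch mirroring the --conf flags above; the endpoint is a placeholder.
def s3ShuffleConf(): SparkConf =
  new SparkConf()
    .set("spark.hadoop.fs.s3a.endpoint", "https://s3.example.com")
    .set("spark.hadoop.fs.s3a.path.style.access", "true")
    .set("spark.hadoop.fs.s3a.fast.upload", "true")
    .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.S3ShuffleManager")
    .set("spark.shuffle.sort.io.plugin.class", "org.apache.spark.shuffle.S3ShuffleDataIO")
    .set("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
```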
33 changes: 17 additions & 16 deletions build.sbt
@@ -24,24 +24,25 @@ buildInfoKeys ++= Seq[BuildInfoKey](
BuildInfoKey.action("sparkVersion") {
sparkVersion
}
)
)

libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % sparkVersion % "provided",
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
"org.apache.spark" %% "spark-hadoop-cloud" % sparkVersion % "compile",
)

libraryDependencies ++= (if (scalaBinaryVersion.value == "2.12") Seq(
"junit" % "junit" % "4.13.2" % Test,
"org.scalatest" %% "scalatest" % "3.2.2" % Test,
"ch.cern.sparkmeasure" %% "spark-measure" % "0.18" % Test,
"org.scalacheck" %% "scalacheck" % "1.15.2" % Test,
"org.mockito" % "mockito-core" % "3.4.6" % Test,
"org.scalatestplus" %% "mockito-3-4" % "3.2.9.0" % Test,
"com.github.sbt" % "junit-interface" % "0.13.3" % Test
)
else Seq())
"org.apache.spark" %% "spark-hadoop-cloud" % sparkVersion % "compile"
)

libraryDependencies ++= (if (scalaBinaryVersion.value == "2.12")
Seq(
"junit" % "junit" % "4.13.2" % Test,
"org.scalatest" %% "scalatest" % "3.2.2" % Test,
"ch.cern.sparkmeasure" %% "spark-measure" % "0.18" % Test,
"org.scalacheck" %% "scalacheck" % "1.15.2" % Test,
"org.mockito" % "mockito-core" % "3.4.6" % Test,
"org.scalatestplus" %% "mockito-3-4" % "3.2.9.0" % Test,
"com.github.sbt" % "junit-interface" % "0.13.3" % Test
)
else Seq())

javacOptions ++= Seq("-source", "1.8", "-target", "1.8")
javaOptions ++= Seq("-Xms512M", "-Xmx2048M", "-XX:MaxPermSize=2048M", "-XX:+CMSClassUnloadingEnabled")
@@ -52,8 +53,8 @@ artifactName := { (sv: ScalaVersion, module: ModuleID, artifact: Artifact) =>
}

assemblyMergeStrategy := {
case PathList("META-INF", xs@_*) => MergeStrategy.discard
case x => MergeStrategy.first
case PathList("META-INF", xs @ _*) => MergeStrategy.discard
case x => MergeStrategy.first
}
assembly / assemblyJarName := s"${name.value}_${scalaBinaryVersion.value}-${sparkVersion}_${version}-with-dependencies.jar"
assembly / assemblyOption ~= {
1 change: 1 addition & 0 deletions project/plugins.sbt
@@ -5,3 +5,4 @@
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "1.2.0") // https://github.com/sbt/sbt-assembly (MIT)
addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.11.0") // https://github.com/sbt/sbt-buildinfo (MIT)
addSbtPlugin("com.github.sbt" % "sbt-git" % "2.0.1") // https://github.com/sbt/sbt-git (BSD-2-Clause)
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2")
26 changes: 15 additions & 11 deletions src/main/scala/org/apache/spark/shuffle/ConcurrentObjectMap.scala
@@ -1,7 +1,7 @@
/**
* Copyright 2022- IBM Inc. All rights reserved
* SPDX-License-Identifier: Apache2.0
*/
//
// Copyright 2022- IBM Inc. All rights reserved
// SPDX-License-Identifier: Apache 2.0
//

package org.apache.spark.shuffle

@@ -20,13 +20,17 @@ class ConcurrentObjectMap[K, V] {
}

def getOrElsePut(key: K, op: K => V): V = {
val l = valueLocks.get(key).getOrElse({
lock.synchronized {
valueLocks.getOrElseUpdate(key, {
new Object()
})
}
})
val l = valueLocks
.get(key)
.getOrElse({
lock.synchronized {
valueLocks.getOrElseUpdate(
key, {
new Object()
}
)
}
})
l.synchronized {
return map.getOrElseUpdate(key, op(key))
}
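
An illustrative usage sketch of the map reformatted above (not part of the diff; `expensiveInit` is a made-up stand-in for whatever value construction a caller supplies): `getOrElsePut` synchronizes on a lock object dedicated to the key, so `op` runs at most once per key even under concurrent callers.

```scala
// Hypothetical caller, for illustration only.
val cache = new ConcurrentObjectMap[Int, String]()

def expensiveInit(shuffleId: Int): String = s"state-for-shuffle-$shuffleId"

// expensiveInit(7) runs only once: a second caller blocks on the per-key lock and then
// receives the value already stored for key 7.
val a = cache.getOrElsePut(7, expensiveInit)
val b = cache.getOrElsePut(7, expensiveInit)
assert(a eq b) // same cached instance
```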
src/main/scala/org/apache/spark/shuffle/S3MeasureOutputStream.scala
@@ -11,7 +11,6 @@ class S3MeasureOutputStream(var out: OutputStream, label: String = "") extends O
private var timings: Long = 0
private var bytes: Long = 0


private def checkOpen(): Unit = {
if (!isOpen) {
throw new IOException("The stream is already closed!")
@@ -58,7 +57,9 @@ class S3MeasureOutputStream(var out: OutputStream, label: String = "") extends O
val sAt = tc.stageAttemptNumber()
val t = timings / 1000000
val bw = bytes.toDouble / (t.toDouble / 1000) / (1024 * 1024)
logInfo(s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " +
s"Writing ${label} ${bytes} took ${t} ms (${bw} MiB/s)")
logInfo(
s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " +
s"Writing ${label} ${bytes} took ${t} ms (${bw} MiB/s)"
)
}
}
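
For illustration (not part of the diff), the throughput figure logged above follows from the two conversions in the hunk: nanoseconds to milliseconds for `t`, then bytes per second to MiB/s for `bw`. With made-up numbers:

```scala
// 256 MiB written, 2 seconds (2e9 ns) accumulated across write() calls.
val bytes: Long = 256L * 1024 * 1024 // 268435456
val timings: Long = 2000000000L      // nanoseconds
val t = timings / 1000000            // 2000 ms
val bw = bytes.toDouble / (t.toDouble / 1000) / (1024 * 1024)
// bw == 128.0, logged as "... took 2000 ms (128.0 MiB/s)"
```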
15 changes: 7 additions & 8 deletions src/main/scala/org/apache/spark/shuffle/S3ShuffleDataIO.scala
@@ -1,7 +1,7 @@
/**
* Copyright 2022- IBM Inc. All rights reserved
* SPDX-License-Identifier: Apache2.0
*/
//
// Copyright 2022- IBM Inc. All rights reserved
// SPDX-License-Identifier: Apache 2.0
//

package org.apache.spark.shuffle

@@ -36,9 +36,9 @@ class S3ShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO {
}

override def createSingleFileMapOutputWriter(
shuffleId: Int,
mapId: Long
): Optional[SingleSpillShuffleMapOutputWriter] = {
shuffleId: Int,
mapId: Long
): Optional[SingleSpillShuffleMapOutputWriter] = {
Optional.of(new S3SingleSpillShuffleMapOutputWriter(shuffleId, mapId))
}
}
@@ -67,4 +67,3 @@ class S3ShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO {
}
}
}

src/main/scala/org/apache/spark/shuffle/S3ShuffleMapOutputWriter.scala
@@ -1,7 +1,7 @@
/**
* Copyright 2022- IBM Inc. All rights reserved
* SPDX-License-Identifier: Apache2.0
*/
//
// Copyright 2022- IBM Inc. All rights reserved
// SPDX-License-Identifier: Apache 2.0
//

package org.apache.spark.shuffle

@@ -19,19 +19,18 @@ import java.nio.ByteBuffer
import java.nio.channels.{Channels, WritableByteChannel}
import java.util.Optional

/**
* Implements the ShuffleMapOutputWriter interface. It stores the shuffle output in one
* shuffle block.
*
* This file is based on Spark "LocalDiskShuffleMapOutputWriter.java".
*/
/** Implements the ShuffleMapOutputWriter interface. It stores the shuffle output in one shuffle block.
*
* This file is based on Spark "LocalDiskShuffleMapOutputWriter.java".
*/

class S3ShuffleMapOutputWriter(
conf: SparkConf,
shuffleId: Int,
mapId: Long,
numPartitions: Int,
) extends ShuffleMapOutputWriter with Logging {
conf: SparkConf,
shuffleId: Int,
mapId: Long,
numPartitions: Int
) extends ShuffleMapOutputWriter
with Logging {
val dispatcher = S3ShuffleDispatcher.get

/* Target block for writing */
@@ -44,7 +43,8 @@ class S3ShuffleMapOutputWriter(
def initStream(): Unit = {
if (stream == null) {
stream = dispatcher.createBlock(shuffleBlock)
bufferedStream = new S3MeasureOutputStream(new BufferedOutputStream(stream, dispatcher.bufferSize), shuffleBlock.name)
bufferedStream =
new S3MeasureOutputStream(new BufferedOutputStream(stream, dispatcher.bufferSize), shuffleBlock.name)
}
}

@@ -59,10 +59,11 @@ class S3ShuffleMapOutputWriter(
private var totalBytesWritten: Long = 0
private var lastPartitionWriterId: Int = -1

/**
* @param reducePartitionId Monotonically increasing, as per contract in ShuffleMapOutputWriter.
* @return An instance of the ShufflePartitionWriter exposing the single output stream.
*/
/** @param reducePartitionId
* Monotonically increasing, as per contract in ShuffleMapOutputWriter.
* @return
* An instance of the ShufflePartitionWriter exposing the single output stream.
*/
override def getPartitionWriter(reducePartitionId: Int): ShufflePartitionWriter = {
if (reducePartitionId <= lastPartitionWriterId) {
throw new RuntimeException("Precondition: Expect a monotonically increasing reducePartitionId.")
@@ -81,19 +82,21 @@ class S3ShuffleMapOutputWriter(
new S3ShufflePartitionWriter(reducePartitionId)
}

/**
* Close all writers and the shuffle block.
*
* @param checksums Ignored.
* @return
*/
/** Close all writers and the shuffle block.
*
* @param checksums
* Ignored.
* @return
*/
override def commitAllPartitions(checksums: Array[Long]): MapOutputCommitMessage = {
if (bufferedStream != null) {
bufferedStream.flush()
}
if (stream != null) {
if (stream.getPos != totalBytesWritten) {
throw new RuntimeException(f"S3ShuffleMapOutputWriter: Unexpected output length ${stream.getPos}, expected: ${totalBytesWritten}.")
throw new RuntimeException(
f"S3ShuffleMapOutputWriter: Unexpected output length ${stream.getPos}, expected: ${totalBytesWritten}."
)
}
}
if (bufferedStreamAsChannel != null) {
@@ -198,8 +201,7 @@ class S3ShuffleMapOutputWriter(
}
}

private class S3ShufflePartitionWriterChannel(reduceId: Int)
extends WritableByteChannelWrapper {
private class S3ShufflePartitionWriterChannel(reduceId: Int) extends WritableByteChannelWrapper {
private val partChannel = new S3PartitionWritableByteChannel(bufferedStreamAsChannel)

override def channel(): WritableByteChannel = {
@@ -216,8 +218,7 @@ class S3ShuffleMapOutputWriter(
}
}

private class S3PartitionWritableByteChannel(channel: WritableByteChannel)
extends WritableByteChannel {
private class S3PartitionWritableByteChannel(channel: WritableByteChannel) extends WritableByteChannel {

private var count: Long = 0

@@ -229,8 +230,7 @@ class S3ShuffleMapOutputWriter(
channel.isOpen()
}

override def close(): Unit = {
}
override def close(): Unit = {}

override def write(x: ByteBuffer): Int = {
var c = 0
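
A usage sketch for the writer above (illustration only, not part of the diff): partition writers must be requested with strictly increasing `reducePartitionId`, and `commitAllPartitions` finishes the single shuffle block. The sketch assumes a running Spark environment so that `S3ShuffleDispatcher.get` can resolve; `openStream()` is the standard Spark `ShufflePartitionWriter` API.

```scala
import org.apache.spark.SparkConf

// Hypothetical driver code, for illustration only.
def writeMapOutput(conf: SparkConf, numPartitions: Int): Unit = {
  val writer = new S3ShuffleMapOutputWriter(conf, shuffleId = 0, mapId = 0L, numPartitions = numPartitions)
  for (reducePartitionId <- 0 until numPartitions) {
    // Must be strictly increasing, otherwise getPartitionWriter throws (see above).
    val partition = writer.getPartitionWriter(reducePartitionId)
    val out = partition.openStream()
    out.write(Array[Byte](1, 2, 3)) // this partition's payload
    out.close()
  }
  writer.commitAllPartitions(Array.emptyLongArray) // checksums are ignored
}
```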
src/main/scala/org/apache/spark/shuffle/S3ShuffleWriter.scala
@@ -19,4 +19,3 @@ class S3ShuffleWriter[K, V](writer: ShuffleWriter[K, V]) extends ShuffleWriter[K

override def getPartitionLengths(): Array[Long] = writer.getPartitionLengths()
}

src/main/scala/org/apache/spark/shuffle/S3SingleSpillShuffleMapOutputWriter.scala
@@ -1,7 +1,7 @@
/**
* Copyright 2023- IBM Inc. All rights reserved
* SPDX-License-Identifier: Apache2.0
*/
//
// Copyright 2023- IBM Inc. All rights reserved
// SPDX-License-Identifier: Apache 2.0
//

package org.apache.spark.shuffle

@@ -15,15 +15,17 @@ import org.apache.spark.util.Utils
import java.io.{File, FileInputStream}
import java.nio.file.{Files, Path}

class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends SingleSpillShuffleMapOutputWriter with Logging {
class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long)
extends SingleSpillShuffleMapOutputWriter
with Logging {

private lazy val dispatcher = S3ShuffleDispatcher.get

override def transferMapSpillFile(
mapSpillFile: File,
partitionLengths: Array[Long],
checksums: Array[Long]
): Unit = {
mapSpillFile: File,
partitionLengths: Array[Long],
checksums: Array[Long]
): Unit = {
val block = ShuffleDataBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)

if (dispatcher.rootIsLocal) {
@@ -44,8 +46,10 @@ class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends S
val sAt = tc.stageAttemptNumber()
val t = timings / 1000000
val bw = bytes.toDouble / (t.toDouble / 1000) / (1024 * 1024)
logInfo(s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " +
s"Writing ${block.name} ${bytes} took ${t} ms (${bw} MiB/s)")
logInfo(
s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " +
s"Writing ${block.name} ${bytes} took ${t} ms (${bw} MiB/s)"
)
} else {
// Copy using a stream.
val in = new FileInputStream(mapSpillFile)