Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 49 additions & 45 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ private[spark] class TaskContextImpl(
@transient private val onInterruptCallbacks = new Stack[TaskInterruptListener]

/**
* The thread currently executing task completion or failure listeners, if any.
*
* `invokeListeners()` uses this to ensure listeners are called sequentially.
*/
Expand Down Expand Up @@ -131,12 +131,11 @@ private[spark] class TaskContextImpl(

override def addTaskInterruptListener(listener: TaskInterruptListener): this.type = {
synchronized {
// If there is already a thread invoking listeners, adding the new listener to
// `onInterruptCallbacks` will cause that thread to execute the new listener, and the call to
// `invokeTaskInterruptListeners()` below will be a no-op.
// If another thread is already running `invokeTaskInterruptListeners`, adding the new
// listener to `onInterruptCallbacks` will cause that thread to execute it (the loop pops
// listeners under the TaskContext lock).
//
// If there is no such thread, the call to `invokeTaskInterruptListeners()` below will execute
// all listeners, including the new listener.
// Otherwise, `invokeTaskInterruptListeners()` below will execute all listeners.
onInterruptCallbacks.push(listener)
reasonIfKilled
}.foreach { reason =>
Expand Down Expand Up @@ -172,47 +171,58 @@ private[spark] class TaskContextImpl(
/**
 * Runs all registered [[TaskCompletionListener]]s for this task.
 *
 * @param error the error that failed the task, if any; forwarded to `invokeListeners()` so a
 *              listener failure can be chained onto the original task error.
 */
private def invokeTaskCompletionListeners(error: Option[Throwable]): Unit = {
  // It is safe to access the reference to `onCompleteCallbacks` without holding the TaskContext
  // lock. `invokeListeners()` acquires the lock before accessing the contents.
  invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
    _.onTaskCompletion(this)
  }
}

/**
 * Runs all registered [[TaskFailureListener]]s for this task.
 *
 * @param error the error that failed the task; passed to each listener and forwarded to
 *              `invokeListeners()` for exception chaining.
 */
private def invokeTaskFailureListeners(error: Throwable): Unit = {
  // It is safe to access the reference to `onFailureCallbacks` without holding the TaskContext
  // lock. `invokeListeners()` acquires the lock before accessing the contents.
  invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) {
    _.onTaskFailure(this, error)
  }
}

/**
 * Runs all registered [[TaskInterruptListener]]s for this task.
 *
 * @param reason the interrupt reason passed to each listener.
 * @param error the error associated with the interrupt, if any; recorded as the cause of the
 *              aggregate exception thrown when one or more listeners fail.
 */
private def invokeTaskInterruptListeners(reason: String, error: Throwable): Unit = {
  // Do not use `invokeListeners()`. That method uses `listenerInvocationThread` to serialize
  // all listener invocations, which would prevent task completion or failure listeners from
  // running if the task completes or fails while executing an interrupt listener (causing
  // resource leaks such as task-managed resources not being freed).
  //
  // Instead, directly execute the task interrupt listeners with independent serialization.
  // Exceptions are collected and thrown as a TaskCompletionListenerException so that
  // `markInterrupted` can catch and record the aggregate failure (SPARK-56330).

  // Pop the next listener under the TaskContext lock so that listeners registered concurrently
  // (see `addTaskInterruptListener`) are still picked up by this loop.
  def getNextListener(): Option[TaskInterruptListener] = synchronized {
    if (onInterruptCallbacks.empty()) {
      None
    } else {
      Some(onInterruptCallbacks.pop())
    }
  }

  val listenerExceptions = new ArrayBuffer[Throwable](2)
  var listenerOption: Option[TaskInterruptListener] = None
  while ({listenerOption = getNextListener(); listenerOption.nonEmpty}) {
    try {
      listenerOption.get.onTaskInterrupted(this, reason)
    } catch {
      case e: Throwable =>
        // Keep running the remaining listeners; the failures are aggregated and rethrown below.
        listenerExceptions += e
        logError(log"Error in TaskInterruptListener", e)
    }
  }
  if (listenerExceptions.nonEmpty) {
    val exception = new TaskCompletionListenerException(
      listenerExceptions.map(_.getMessage).toSeq, Option(error))
    listenerExceptions.foreach(exception.addSuppressed)
    throw exception
  }
}

private def invokeListeners[T](
listeners: Stack[T],
name: String,
error: Option[Throwable],
markTaskFailedOnListenerError: Boolean)(
error: Option[Throwable])(
callback: T => Unit): Unit = {
// This method is subject to two constraints:
//
Expand Down Expand Up @@ -255,8 +265,7 @@ private[spark] class TaskContextImpl(
callback(listener)
} catch {
case e: Throwable =>
// A listener failed. For completion/failure listeners, temporarily clear
// listenerInvocationThread and markTaskFailed so nested TaskContext calls can run.
// A listener failed. Temporarily clear the listenerInvocationThread and markTaskFailed.
//
// One of the following cases applies (#3 being the interesting one):
//
Expand Down Expand Up @@ -290,20 +299,15 @@ private[spark] class TaskContextImpl(
// failed, and now another completion listener has failed. Then our call to
// [[markTaskFailed]] here will have no effect and we simply resume running the
// remaining completion handlers.
//
// Task interrupt listeners skip per-listener [[markTaskFailed]]; see
// [[invokeTaskInterruptListeners]].
if (markTaskFailedOnListenerError) {
try {
listenerInvocationThread = None
markTaskFailed(e)
} catch {
case t: Throwable => e.addSuppressed(t)
} finally {
synchronized {
if (listenerInvocationThread.isEmpty) {
listenerInvocationThread = Some(Thread.currentThread())
}
try {
listenerInvocationThread = None
markTaskFailed(e)
} catch {
case t: Throwable => e.addSuppressed(t)
} finally {
synchronized {
if (listenerInvocationThread.isEmpty) {
listenerInvocationThread = Some(Thread.currentThread())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,41 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(invocations == 2)
}

test("SPARK-56330: task completion during interrupt listener execution") {
  val context = TaskContext.empty()
  val completionListener = mock(classOf[TaskCompletionListener])
  context.addTaskCompletionListener(completionListener)

  // Two gates: `started` signals that the interrupt listener has begun running on the
  // interrupt thread; `release` keeps it parked there until this thread finishes its checks.
  val started = new Semaphore(0)
  val release = new Semaphore(0)
  context.addTaskInterruptListener(new TaskInterruptListener {
    override def onTaskInterrupted(context: TaskContext, reason: String): Unit = {
      started.release()
      release.acquire()
    }
  })

  // Kick off the interrupt from a background thread and block until the interrupt listener
  // is definitely executing.
  val interruptThread = new Thread(() => context.markInterrupted("test interrupt"))
  interruptThread.start()
  started.acquire()

  // Complete the task on this thread while the interrupt listener is still blocked. With the
  // dedicated interrupt-listener loop, this call must not be blocked.
  context.markTaskCompleted(None)

  // If interrupt and completion listeners shared `invokeListeners()`, the completion listener
  // would be silently skipped here because `listenerInvocationThread` would still be held by
  // the interrupt thread. Verify it actually ran.
  verify(completionListener, times(1)).onTaskCompletion(any())

  // Unblock the interrupt listener and wait for the background thread to exit.
  release.release()
  interruptThread.join()
}

test("FailureListener throws after task body fails") {
val context = TaskContext.empty()
val listenerCalls = ArrayBuffer.empty[String]
Expand Down