diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs
new file mode 100644
index 000000000..8e59694ac
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs
@@ -0,0 +1,121 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Threading;
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Services;
+using Microsoft.Spark.Sql;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests
+{
+ [Collection("Spark E2E Tests")]
+ public class JvmThreadPoolGCTests
+ {
+ private readonly ILoggerService _loggerService;
+ private readonly SparkSession _spark;
+ private readonly IJvmBridge _jvmBridge;
+
+ public JvmThreadPoolGCTests(SparkFixture fixture)
+ {
+ _loggerService = LoggerServiceFactory.GetLogger(typeof(JvmThreadPoolGCTests));
+ _spark = fixture.Spark;
+ _jvmBridge = ((IJvmObjectReferenceProvider)_spark).Reference.Jvm;
+ }
+
+ /// <summary>
+ /// Test that the active SparkSession is thread-specific.
+ /// </summary>
+ [Fact]
+ public void TestThreadLocalSessions()
+ {
+ SparkSession.ClearActiveSession();
+
+ void testChildThread(string appName)
+ {
+ var thread = new Thread(() =>
+ {
+ Assert.Null(SparkSession.GetActiveSession());
+
+ SparkSession.SetActiveSession(
+ SparkSession.Builder().AppName(appName).GetOrCreate());
+
+ // Since we are in the child thread, GetActiveSession() should return the child
+ // SparkSession.
+ SparkSession activeSession = SparkSession.GetActiveSession();
+ Assert.NotNull(activeSession);
+ Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null));
+ });
+
+ thread.Start();
+ thread.Join();
+ }
+
+ for (int i = 0; i < 5; ++i)
+ {
+ testChildThread(i.ToString());
+ }
+
+ Assert.Null(SparkSession.GetActiveSession());
+ }
+
+ /// <summary>
+ /// Monitor a thread via the JvmThreadPoolGC.
+ /// </summary>
+ [Fact]
+ public void TestTryAddThread()
+ {
+ using var threadPool = new JvmThreadPoolGC(
+ _loggerService, _jvmBridge, TimeSpan.FromMinutes(30));
+
+ var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
+ thread.Start();
+
+ Assert.True(threadPool.TryAddThread(thread));
+ // Subsequent call should return false, because the thread has already been added.
+ Assert.False(threadPool.TryAddThread(thread));
+
+ thread.Join();
+ }
+
+ /// <summary>
+ /// Create a Spark worker thread in the JVM ThreadPool then remove it directly through
+ /// the JvmBridge.
+ /// </summary>
+ [Fact]
+ public void TestRmThread()
+ {
+ // Create a thread and ensure that it is initialized in the JVM ThreadPool.
+ var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
+ thread.Start();
+ thread.Join();
+
+ // First call should return true. Second call should return false.
+ Assert.True((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId));
+ Assert.False((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId));
+ }
+
+ /// <summary>
+ /// Test that the GC interval configuration defaults to 5 minutes, and can be updated
+ /// correctly by setting the environment variable.
+ /// </summary>
+ [Fact]
+ public void TestIntervalConfiguration()
+ {
+ // Default value is 5 minutes.
+ Assert.Null(Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL"));
+ Assert.Equal(
+ TimeSpan.FromMinutes(5),
+ SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
+
+ // Test a custom value.
+ Environment.SetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL", "1:30:00");
+ Assert.Equal(
+ TimeSpan.FromMinutes(90),
+ SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
+ }
+ }
+}
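The scenario these tests exercise is a multithreaded driver where each thread works against its own session. A minimal sketch of that usage pattern, assuming a working .NET for Apache Spark deployment (the app names and the query below are illustrative):

using System.Threading;
using Microsoft.Spark.Sql;

class ThreadLocalSessionExample
{
    static void Main()
    {
        var threads = new Thread[2];
        for (int i = 0; i < threads.Length; ++i)
        {
            string appName = $"worker-{i}";
            threads[i] = new Thread(() =>
            {
                // Each .NET thread is backed by its own JVM thread, so the active
                // session set here is not visible to the other worker thread.
                SparkSession session = SparkSession.Builder().AppName(appName).GetOrCreate();
                SparkSession.SetActiveSession(session);
                session.Sql("SELECT 1").Show();
            });
            threads[i].Start();
        }

        foreach (Thread t in threads)
        {
            t.Join();
        }
    }
}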
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs
index 6cf56617d..cc1542a42 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs
@@ -35,6 +35,10 @@ public void TestSignaturesV2_3_X()
Assert.IsType<Builder>(SparkSession.Builder());
+ SparkSession.ClearActiveSession();
+ SparkSession.SetActiveSession(_spark);
+ Assert.IsType<SparkSession>(SparkSession.GetActiveSession());
+
SparkSession.ClearDefaultSession();
SparkSession.SetDefaultSession(_spark);
Assert.IsType<SparkSession>(SparkSession.GetDefaultSession());
@@ -76,7 +80,7 @@ public void TestSignaturesV2_4_X()
/// </summary>
[Fact]
public void TestCreateDataFrame()
- {
+ {
// Calling CreateDataFrame with schema
{
var data = new List<GenericRow>
diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs
index 0fb489a91..1dc53ef13 100644
--- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs
+++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs
@@ -8,6 +8,7 @@
using System.IO;
using System.Net;
using System.Text;
+using System.Threading;
using Microsoft.Spark.Network;
using Microsoft.Spark.Services;
@@ -35,6 +36,7 @@ internal sealed class JvmBridge : IJvmBridge
private readonly ILoggerService _logger =
LoggerServiceFactory.GetLogger(typeof(JvmBridge));
private readonly int _portNumber;
+ private readonly JvmThreadPoolGC _jvmThreadPoolGC;
internal JvmBridge(int portNumber)
{
@@ -45,6 +47,9 @@ internal JvmBridge(int portNumber)
_portNumber = portNumber;
_logger.LogInfo($"JvMBridge port is {portNumber}");
+
+ _jvmThreadPoolGC = new JvmThreadPoolGC(
+ _logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
}
private ISocketWrapper GetConnection()
@@ -158,11 +163,13 @@ private object CallJavaMethod(
ISocketWrapper socket = null;
try
{
+ Thread thread = Thread.CurrentThread;
MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
payloadMemoryStream.Position = 0;
PayloadHelper.BuildPayload(
payloadMemoryStream,
isStatic,
+ thread.ManagedThreadId,
classNameOrJvmObjectReference,
methodName,
args);
@@ -176,6 +183,8 @@ private object CallJavaMethod(
(int)payloadMemoryStream.Position);
outputStream.Flush();
+ _jvmThreadPoolGC.TryAddThread(thread);
+
Stream inputStream = socket.InputStream;
int isMethodCallFailed = SerDe.ReadInt32(inputStream);
if (isMethodCallFailed != 0)
@@ -410,6 +419,7 @@ private object ReadCollection(Stream s)
public void Dispose()
{
+ _jvmThreadPoolGC.Dispose();
while (_sockets.TryDequeue(out ISocketWrapper socket))
{
if (socket != null)
diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs
new file mode 100644
index 000000000..0eacebadd
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs
@@ -0,0 +1,149 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Threading;
+using Microsoft.Spark.Services;
+
+namespace Microsoft.Spark.Interop.Ipc
+{
+ /// <summary>
+ /// In .NET for Apache Spark, we maintain a 1-to-1 mapping between .NET application threads
+ /// and corresponding JVM threads. When a .NET thread calls a Spark API, that call is executed
+ /// by its corresponding JVM thread. This functionality allows for multithreaded applications
+ /// with thread-local variables.
+ ///
+ /// This class keeps track of the .NET application thread lifecycle. When a .NET application
+ /// thread is no longer alive, this class submits an rmThread command to the JVM backend to
+ /// dispose of its corresponding JVM thread. All methods are thread-safe.
+ /// </summary>
+ internal class JvmThreadPoolGC : IDisposable
+ {
+ private readonly ILoggerService _loggerService;
+ private readonly IJvmBridge _jvmBridge;
+ private readonly TimeSpan _threadGCInterval;
+ private readonly ConcurrentDictionary<int, Thread> _activeThreads;
+
+ private readonly object _activeThreadGCTimerLock;
+ private Timer _activeThreadGCTimer;
+
+ /// <summary>
+ /// Construct the JvmThreadPoolGC.
+ /// </summary>
+ /// <param name="loggerService">Logger service.</param>
+ /// <param name="jvmBridge">The JvmBridge used to call JVM methods.</param>
+ /// <param name="threadGCInterval">The interval to GC finished threads.</param>
+ public JvmThreadPoolGC(ILoggerService loggerService, IJvmBridge jvmBridge, TimeSpan threadGCInterval)
+ {
+ _loggerService = loggerService;
+ _jvmBridge = jvmBridge;
+ _threadGCInterval = threadGCInterval;
+ _activeThreads = new ConcurrentDictionary<int, Thread>();
+
+ _activeThreadGCTimerLock = new object();
+ _activeThreadGCTimer = null;
+ }
+
+ /// <summary>
+ /// Dispose of the GC timer and run a final round of thread GC.
+ /// </summary>
+ public void Dispose()
+ {
+ lock (_activeThreadGCTimerLock)
+ {
+ if (_activeThreadGCTimer != null)
+ {
+ _activeThreadGCTimer.Dispose();
+ _activeThreadGCTimer = null;
+ }
+ }
+
+ GCThreads();
+ }
+
+ /// <summary>
+ /// Try to start monitoring a thread.
+ /// </summary>
+ /// <param name="thread">The thread to add.</param>
+ /// <returns>True if success, false if already added.</returns>
+ public bool TryAddThread(Thread thread)
+ {
+ bool returnValue = _activeThreads.TryAdd(thread.ManagedThreadId, thread);
+
+ // Initialize the GC timer if necessary.
+ if (_activeThreadGCTimer == null)
+ {
+ lock (_activeThreadGCTimerLock)
+ {
+ if (_activeThreadGCTimer == null && _activeThreads.Count > 0)
+ {
+ _activeThreadGCTimer = new Timer(
+ (state) => GCThreads(),
+ null,
+ _threadGCInterval,
+ _threadGCInterval);
+ }
+ }
+ }
+
+ return returnValue;
+ }
+
+ /// <summary>
+ /// Try to remove a thread from the pool. If the removal is successful, then the
+ /// corresponding JVM thread will also be disposed.
+ /// </summary>
+ /// <param name="threadId">The ID of the thread to remove.</param>
+ /// <returns>True if success, false if the thread cannot be found.</returns>
+ private bool TryDisposeJvmThread(int threadId)
+ {
+ if (_activeThreads.TryRemove(threadId, out _))
+ {
+ // _activeThreads does not have ownership of the threads on the .NET side. This
+ // class does not need to call Join() on the .NET Thread. However, this class is
+ // responsible for sending the rmThread command to the JVM to trigger disposal
+ // of the corresponding JVM thread.
+ if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId))
+ {
+ _loggerService.LogDebug($"GC'd JVM thread {threadId}.");
+ return true;
+ }
+ else
+ {
+ _loggerService.LogWarn(
+ $"rmThread returned false for JVM thread {threadId}. " +
+ $"Either thread does not exist or has already been GC'd.");
+ }
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Remove any threads that are no longer active.
+ /// </summary>
+ private void GCThreads()
+ {
+ foreach (KeyValuePair<int, Thread> kvp in _activeThreads)
+ {
+ if (!kvp.Value.IsAlive)
+ {
+ TryDisposeJvmThread(kvp.Key);
+ }
+ }
+
+ lock (_activeThreadGCTimerLock)
+ {
+ // Dispose of the timer if there are no threads to monitor.
+ if (_activeThreadGCTimer != null && _activeThreads.IsEmpty)
+ {
+ _activeThreadGCTimer.Dispose();
+ _activeThreadGCTimer = null;
+ }
+ }
+ }
+ }
+}
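Stripped of the JVM specifics, JvmThreadPoolGC boils down to a liveness sweep over a ConcurrentDictionary keyed by ManagedThreadId. A minimal stand-alone sketch of that pattern (names are illustrative; the cleanup delegate stands in for the rmThread call to the backend):

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;

internal sealed class ThreadLivenessSweeper : IDisposable
{
    private readonly ConcurrentDictionary<int, Thread> _threads =
        new ConcurrentDictionary<int, Thread>();
    private readonly Action<int> _cleanup;
    private readonly Timer _timer;

    public ThreadLivenessSweeper(Action<int> cleanup, TimeSpan interval)
    {
        _cleanup = cleanup;
        _timer = new Timer(_ => Sweep(), null, interval, interval);
    }

    // Mirrors JvmThreadPoolGC.TryAddThread: false when the thread is already tracked.
    public bool TryAddThread(Thread thread) =>
        _threads.TryAdd(thread.ManagedThreadId, thread);

    private void Sweep()
    {
        foreach (KeyValuePair<int, Thread> kvp in _threads)
        {
            if (!kvp.Value.IsAlive && _threads.TryRemove(kvp.Key, out _))
            {
                // In JvmThreadPoolGC this is where the rmThread command is sent.
                _cleanup(kvp.Key);
            }
        }
    }

    public void Dispose()
    {
        _timer.Dispose();
        Sweep();
    }
}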
diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs
index 3d7315b2a..569744713 100644
--- a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs
+++ b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs
@@ -27,7 +27,7 @@ internal class PayloadHelper
private static readonly byte[] s_timestampTypeId = new[] { (byte)'t' };
private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' };
private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' };
- private static readonly byte[] s_doubleArrayArrayTypeId = new[] { ( byte)'A' };
+ private static readonly byte[] s_doubleArrayArrayTypeId = new[] { (byte)'A' };
private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' };
private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };
@@ -39,6 +39,7 @@ internal class PayloadHelper
internal static void BuildPayload(
MemoryStream destination,
bool isStaticMethod,
+ int threadId,
object classNameOrJvmObjectReference,
string methodName,
object[] args)
@@ -48,6 +49,7 @@ internal static void BuildPayload(
destination.Position += sizeof(int);
SerDe.Write(destination, isStaticMethod);
+ SerDe.Write(destination, threadId);
SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
SerDe.Write(destination, methodName);
SerDe.Write(destination, args.Length);
@@ -140,7 +142,7 @@ internal static void ConvertArgsToBytes(
SerDe.Write(destination, d);
}
break;
-
+
case double[][] argDoubleArrayArray:
SerDe.Write(destination, s_doubleArrayArrayTypeId);
SerDe.Write(destination, argDoubleArrayArray.Length);
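For reference, the request header now carries the caller's managed thread id between the static-method flag and the target name. The sketch below mirrors that write order; WriteRequestHeader is a hypothetical helper, and since SerDe is internal to Microsoft.Spark this is illustrative rather than something that compiles from outside the assembly:

using System.IO;
using Microsoft.Spark.Interop.Ipc;

internal static class PayloadSketch
{
    // Hypothetical helper showing the field order BuildPayload emits; the Scala
    // DotnetBackendHandler reads the same fields back in the same order with
    // readBoolean / readInt / readString / readString / readInt.
    internal static void WriteRequestHeader(
        Stream destination,
        bool isStaticMethod,
        int threadId,                   // new in this change: Thread.CurrentThread.ManagedThreadId
        string classNameOrJvmObjectId,
        string methodName,
        int numArgs)
    {
        SerDe.Write(destination, isStaticMethod);
        SerDe.Write(destination, threadId);
        SerDe.Write(destination, classNameOrJvmObjectId);
        SerDe.Write(destination, methodName);
        SerDe.Write(destination, numArgs);
        // Type-tagged arguments follow, and the total payload length is backfilled
        // at position 0 once the body has been written.
    }
}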
diff --git a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs
index 3b7de1555..4ce565c84 100644
--- a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs
+++ b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs
@@ -33,6 +33,18 @@ internal sealed class ConfigurationService : IConfigurationService
private string _workerPath;
+ /// <summary>
+ /// How often to run GC on JVM ThreadPool threads. Defaults to 5 minutes.
+ /// </summary>
+ public TimeSpan JvmThreadGCInterval
+ {
+ get
+ {
+ string envVar = Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL");
+ return string.IsNullOrEmpty(envVar) ? TimeSpan.FromMinutes(5) : TimeSpan.Parse(envVar);
+ }
+ }
+
/// <summary>
/// Returns the port number for socket communication between JVM and CLR.
/// </summary>
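The interval string is read back with TimeSpan.Parse, so any standard TimeSpan literal works. A small illustration (the values are examples, not defaults):

using System;

class GcIntervalExample
{
    static void Main()
    {
        // "hh:mm:ss" style strings are accepted; "1:30:00" is the 90 minutes asserted
        // in TestIntervalConfiguration above.
        Environment.SetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL", "00:01:00");
        Console.WriteLine(TimeSpan.Parse("1:30:00").TotalMinutes); // 90

        // Unset (or empty) falls back to the 5 minute default.
        Environment.SetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL", null);
    }
}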
diff --git a/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs
index 5c7a4074f..5398632bd 100644
--- a/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs
+++ b/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs
@@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
+
namespace Microsoft.Spark.Services
{
/// <summary>
@@ -9,6 +11,11 @@ namespace Microsoft.Spark.Services
/// </summary>
internal interface IConfigurationService
{
+ /// <summary>
+ /// How often to run GC on JVM ThreadPool threads.
+ /// </summary>
+ TimeSpan JvmThreadGCInterval { get; }
+
/// <summary>
/// The port number used for communicating with the .NET backend process.
/// </summary>
diff --git a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs
index cc2e70e26..c90bc3ce9 100644
--- a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs
+++ b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs
@@ -61,10 +61,40 @@ internal SparkSession(JvmObjectReference jvmObject)
/// <returns>Builder object</returns>
public static Builder Builder() => new Builder();
- /// Note that *ActiveSession() APIs are not exposed because these APIs work with a
- /// thread-local variable, which stores the session variable. Since the Netty server
- /// that handles the requests is multi-threaded, any thread can invoke these APIs,
- /// resulting in unexpected behaviors if different threads are used.
+ /// <summary>
+ /// Changes the SparkSession that will be returned in this thread when
+ /// <see cref="GetActiveSession()"/> is called. This can be used to ensure that a given
+ /// thread receives a SparkSession with an isolated session, instead of the global
+ /// (first created) context.
+ /// </summary>
+ /// <param name="session">SparkSession object</param>
+ public static void SetActiveSession(SparkSession session) =>
+ session._jvmObject.Jvm.CallStaticJavaMethod(
+ s_sparkSessionClassName, "setActiveSession", session);
+
+ /// <summary>
+ /// Clears the active SparkSession for current thread. Subsequent calls to
+ /// <see cref="GetActiveSession()"/> will return the first created context
+ /// instead of a thread-local override.
+ /// </summary>
+ public static void ClearActiveSession() =>
+ SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+ s_sparkSessionClassName, "clearActiveSession");
+
+ /// <summary>
+ /// Returns the active SparkSession for the current thread, returned by the builder.
+ /// </summary>
+ /// <returns>Return null, when calling this function on executors</returns>
+ public static SparkSession GetActiveSession()
+ {
+ var optionalSession = new Option(
+ (JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+ s_sparkSessionClassName, "getActiveSession"));
+
+ return optionalSession.IsDefined()
+ ? new SparkSession((JvmObjectReference)optionalSession.Get())
+ : null;
+ }
/// <summary>
/// Sets the default SparkSession that is returned by the builder.
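A short usage sketch of the newly exposed APIs; the spark argument is assumed to be an existing session (for example from Builder().GetOrCreate()):

using Microsoft.Spark.Sql;

static class ActiveSessionExample
{
    public static void Demo(SparkSession spark)
    {
        SparkSession.ClearActiveSession();
        // As exercised in TestThreadLocalSessions, nothing is set on this thread now,
        // so GetActiveSession() returns null.
        SparkSession before = SparkSession.GetActiveSession();

        SparkSession.SetActiveSession(spark);
        // The thread-local active session is now the one set above.
        SparkSession after = SparkSession.GetActiveSession();
    }
}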
diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
index 1cde1d1c5..e632589e4 100644
--- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
+++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
@@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import scala.collection.mutable.HashMap
+import scala.language.existentials
+
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import org.apache.spark.api.dotnet.SerDe._
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
-import scala.collection.mutable.HashMap
-import scala.language.existentials
-
/**
* Handler for DotnetBackend.
* This implementation is similar to RBackendHandler.
@@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
// First bit is isStatic
val isStatic = readBoolean(dis)
+ val threadId = readInt(dis)
val objId = readString(dis)
val methodName = readString(dis)
val numArgs = readInt(dis)
@@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
logError(s"Removing $objId failed", e)
writeInt(dos, -1)
}
+ case "rmThread" =>
+ try {
+ assert(readObjectType(dis) == 'i')
+ val threadToDelete = readInt(dis)
+ val result = ThreadPool.tryDeleteThread(threadToDelete)
+ writeInt(dos, 0)
+ writeObject(dos, result.asInstanceOf[AnyRef])
+ } catch {
+ case e: Exception =>
+ logError(s"Removing thread $threadId failed", e)
+ writeInt(dos, -1)
+ }
case "connectCallback" =>
assert(readObjectType(dis) == 'c')
val address = readString(dis)
assert(readObjectType(dis) == 'i')
val port = readInt(dis)
- DotnetBackend.setCallbackClient(address, port);
+ DotnetBackend.setCallbackClient(address, port)
writeInt(dos, 0)
writeType(dos, "void")
case "closeCallback" =>
@@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
case _ => dos.writeInt(-1)
}
} else {
- handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
+ ThreadPool
+ .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
}
bos.toByteArray
@@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
"invalid method " + methodName + " for object " + objId)
}
} catch {
- case e: Exception =>
+ case e: Throwable =>
val jvmObj = JVMObjectTracker.get(objId)
val jvmObjName = jvmObj match {
case Some(jObj) => jObj.getClass.getName
case None => "NullObject"
}
- val argsStr = args.map(arg => {
- if (arg != null) {
- s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
- } else {
- "[Value: NULL]"
- }
- }).mkString(", ")
+ val argsStr = args
+ .map(arg => {
+ if (arg != null) {
+ s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
+ } else {
+ "[Value: NULL]"
+ }
+ })
+ .mkString(", ")
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
new file mode 100644
index 000000000..1888ec746
--- /dev/null
+++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.util.concurrent.{ExecutorService, Executors}
+
+import scala.collection.mutable
+
+/**
+ * Pool of thread executors. There should be a 1-1 correspondence between C# threads
+ * and Java threads.
+ */
+object ThreadPool {
+
+ /**
+ * Map from threadId to corresponding executor.
+ */
+ private val executors: mutable.HashMap[Int, ExecutorService] =
+ new mutable.HashMap[Int, ExecutorService]()
+
+ /**
+ * Run some code on a particular thread.
+ *
+ * @param threadId Integer id of the thread.
+ * @param task Function to run on the thread.
+ */
+ def run(threadId: Int, task: () => Unit): Unit = {
+ val executor = getOrCreateExecutor(threadId)
+ val future = executor.submit(new Runnable {
+ override def run(): Unit = task()
+ })
+
+ future.get()
+ }
+
+ /**
+ * Try to delete a particular thread.
+ *
+ * @param threadId Integer id of the thread.
+ * @return True if successful, false if thread does not exist.
+ */
+ def tryDeleteThread(threadId: Int): Boolean = synchronized {
+ executors.remove(threadId) match {
+ case Some(executorService) =>
+ executorService.shutdown()
+ true
+ case None => false
+ }
+ }
+
+ /**
+ * Get the executor if it exists, otherwise create a new one.
+ *
+ * @param id Integer id of the thread.
+ * @return The new or existing executor with the given id.
+ */
+ private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
+ executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
+ }
+}
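ThreadPool keeps one single-thread executor per .NET thread id, so every call issued from a given .NET thread runs on the same JVM thread and JVM thread-local state (such as the active SparkSession) stays stable across calls. A rough C# analogue of that per-id, thread-affine worker pattern, purely for illustration (all names are made up and error handling is minimal):

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;

// One dedicated thread per worker; tasks submitted to a worker run sequentially
// on that thread, mimicking Executors.newSingleThreadExecutor.
internal sealed class DedicatedWorker : IDisposable
{
    private readonly BlockingCollection<Action> _queue = new BlockingCollection<Action>();
    private readonly Thread _thread;

    public DedicatedWorker()
    {
        _thread = new Thread(() =>
        {
            foreach (Action task in _queue.GetConsumingEnumerable())
            {
                task();
            }
        });
        _thread.IsBackground = true;
        _thread.Start();
    }

    // Queue the task on the dedicated thread and block until it completes,
    // like ThreadPool.run waiting on future.get().
    public void Run(Action task)
    {
        using (var done = new ManualResetEventSlim())
        {
            _queue.Add(() => { try { task(); } finally { done.Set(); } });
            done.Wait();
        }
    }

    public void Dispose()
    {
        _queue.CompleteAdding(); // analogous to ExecutorService.shutdown()
        _thread.Join();
    }
}

internal static class PerIdWorkerPool
{
    private static readonly Dictionary<int, DedicatedWorker> s_workers =
        new Dictionary<int, DedicatedWorker>();
    private static readonly object s_lock = new object();

    // Counterpart of getOrCreateExecutor + run.
    public static void Run(int id, Action task)
    {
        DedicatedWorker worker;
        lock (s_lock)
        {
            if (!s_workers.TryGetValue(id, out worker))
            {
                worker = new DedicatedWorker();
                s_workers[id] = worker;
            }
        }
        worker.Run(task);
    }

    // Counterpart of tryDeleteThread: true if the worker existed and was shut down.
    public static bool TryDelete(int id)
    {
        lock (s_lock)
        {
            if (s_workers.TryGetValue(id, out DedicatedWorker worker))
            {
                s_workers.Remove(id);
                worker.Dispose();
                return true;
            }
            return false;
        }
    }
}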
diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
index 1cde1d1c5..e632589e4 100644
--- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
+++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
@@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import scala.collection.mutable.HashMap
+import scala.language.existentials
+
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import org.apache.spark.api.dotnet.SerDe._
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
-import scala.collection.mutable.HashMap
-import scala.language.existentials
-
/**
* Handler for DotnetBackend.
* This implementation is similar to RBackendHandler.
@@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
// First bit is isStatic
val isStatic = readBoolean(dis)
+ val threadId = readInt(dis)
val objId = readString(dis)
val methodName = readString(dis)
val numArgs = readInt(dis)
@@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
logError(s"Removing $objId failed", e)
writeInt(dos, -1)
}
+ case "rmThread" =>
+ try {
+ assert(readObjectType(dis) == 'i')
+ val threadToDelete = readInt(dis)
+ val result = ThreadPool.tryDeleteThread(threadToDelete)
+ writeInt(dos, 0)
+ writeObject(dos, result.asInstanceOf[AnyRef])
+ } catch {
+ case e: Exception =>
+ logError(s"Removing thread $threadId failed", e)
+ writeInt(dos, -1)
+ }
case "connectCallback" =>
assert(readObjectType(dis) == 'c')
val address = readString(dis)
assert(readObjectType(dis) == 'i')
val port = readInt(dis)
- DotnetBackend.setCallbackClient(address, port);
+ DotnetBackend.setCallbackClient(address, port)
writeInt(dos, 0)
writeType(dos, "void")
case "closeCallback" =>
@@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
case _ => dos.writeInt(-1)
}
} else {
- handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
+ ThreadPool
+ .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
}
bos.toByteArray
@@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
"invalid method " + methodName + " for object " + objId)
}
} catch {
- case e: Exception =>
+ case e: Throwable =>
val jvmObj = JVMObjectTracker.get(objId)
val jvmObjName = jvmObj match {
case Some(jObj) => jObj.getClass.getName
case None => "NullObject"
}
- val argsStr = args.map(arg => {
- if (arg != null) {
- s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
- } else {
- "[Value: NULL]"
- }
- }).mkString(", ")
+ val argsStr = args
+ .map(arg => {
+ if (arg != null) {
+ s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
+ } else {
+ "[Value: NULL]"
+ }
+ })
+ .mkString(", ")
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
new file mode 100644
index 000000000..1888ec746
--- /dev/null
+++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.util.concurrent.{ExecutorService, Executors}
+
+import scala.collection.mutable
+
+/**
+ * Pool of thread executors. There should be a 1-1 correspondence between C# threads
+ * and Java threads.
+ */
+object ThreadPool {
+
+ /**
+ * Map from threadId to corresponding executor.
+ */
+ private val executors: mutable.HashMap[Int, ExecutorService] =
+ new mutable.HashMap[Int, ExecutorService]()
+
+ /**
+ * Run some code on a particular thread.
+ *
+ * @param threadId Integer id of the thread.
+ * @param task Function to run on the thread.
+ */
+ def run(threadId: Int, task: () => Unit): Unit = {
+ val executor = getOrCreateExecutor(threadId)
+ val future = executor.submit(new Runnable {
+ override def run(): Unit = task()
+ })
+
+ future.get()
+ }
+
+ /**
+ * Try to delete a particular thread.
+ *
+ * @param threadId Integer id of the thread.
+ * @return True if successful, false if thread does not exist.
+ */
+ def tryDeleteThread(threadId: Int): Boolean = synchronized {
+ executors.remove(threadId) match {
+ case Some(executorService) =>
+ executorService.shutdown()
+ true
+ case None => false
+ }
+ }
+
+ /**
+ * Get the executor if it exists, otherwise create a new one.
+ *
+ * @param id Integer id of the thread.
+ * @return The new or existing executor with the given id.
+ */
+ private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
+ executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
+ }
+}
diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
index 1cde1d1c5..1446e5ff6 100644
--- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
+++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
@@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend)
// First bit is isStatic
val isStatic = readBoolean(dis)
+ val threadId = readInt(dis)
val objId = readString(dis)
val methodName = readString(dis)
val numArgs = readInt(dis)
@@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend)
logError(s"Removing $objId failed", e)
writeInt(dos, -1)
}
+ case "rmThread" =>
+ try {
+ assert(readObjectType(dis) == 'i')
+ val threadToDelete = readInt(dis)
+ val result = ThreadPool.tryDeleteThread(threadToDelete)
+ writeInt(dos, 0)
+ writeObject(dos, result.asInstanceOf[AnyRef])
+ } catch {
+ case e: Exception =>
+ logError(s"Removing thread $threadId failed", e)
+ writeInt(dos, -1)
+ }
case "connectCallback" =>
assert(readObjectType(dis) == 'c')
val address = readString(dis)
assert(readObjectType(dis) == 'i')
val port = readInt(dis)
- DotnetBackend.setCallbackClient(address, port);
+ DotnetBackend.setCallbackClient(address, port)
writeInt(dos, 0)
writeType(dos, "void")
case "closeCallback" =>
@@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend)
case _ => dos.writeInt(-1)
}
} else {
- handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
+ ThreadPool
+ .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
}
bos.toByteArray
@@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend)
"invalid method " + methodName + " for object " + objId)
}
} catch {
- case e: Exception =>
+ case e: Throwable =>
val jvmObj = JVMObjectTracker.get(objId)
val jvmObjName = jvmObj match {
case Some(jObj) => jObj.getClass.getName
case None => "NullObject"
}
- val argsStr = args.map(arg => {
- if (arg != null) {
- s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
- } else {
- "[Value: NULL]"
- }
- }).mkString(", ")
+ val argsStr = args
+ .map(arg => {
+ if (arg != null) {
+ s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
+ } else {
+ "[Value: NULL]"
+ }
+ })
+ .mkString(", ")
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
new file mode 100644
index 000000000..1888ec746
--- /dev/null
+++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.util.concurrent.{ExecutorService, Executors}
+
+import scala.collection.mutable
+
+/**
+ * Pool of thread executors. There should be a 1-1 correspondence between C# threads
+ * and Java threads.
+ */
+object ThreadPool {
+
+ /**
+ * Map from threadId to corresponding executor.
+ */
+ private val executors: mutable.HashMap[Int, ExecutorService] =
+ new mutable.HashMap[Int, ExecutorService]()
+
+ /**
+ * Run some code on a particular thread.
+ *
+ * @param threadId Integer id of the thread.
+ * @param task Function to run on the thread.
+ */
+ def run(threadId: Int, task: () => Unit): Unit = {
+ val executor = getOrCreateExecutor(threadId)
+ val future = executor.submit(new Runnable {
+ override def run(): Unit = task()
+ })
+
+ future.get()
+ }
+
+ /**
+ * Try to delete a particular thread.
+ *
+ * @param threadId Integer id of the thread.
+ * @return True if successful, false if thread does not exist.
+ */
+ def tryDeleteThread(threadId: Int): Boolean = synchronized {
+ executors.remove(threadId) match {
+ case Some(executorService) =>
+ executorService.shutdown()
+ true
+ case None => false
+ }
+ }
+
+ /**
+ * Get the executor if it exists, otherwise create a new one.
+ *
+ * @param id Integer id of the thread.
+ * @return The new or existing executor with the given id.
+ */
+ private def getOrCreateExecutor(id: Int): ExecutorService = synchronized {
+ executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor)
+ }
+}