diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs new file mode 100644 index 000000000..8e59694ac --- /dev/null +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs @@ -0,0 +1,121 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; +using Microsoft.Spark.Interop; +using Microsoft.Spark.Interop.Ipc; +using Microsoft.Spark.Services; +using Microsoft.Spark.Sql; +using Xunit; + +namespace Microsoft.Spark.E2ETest.IpcTests +{ + [Collection("Spark E2E Tests")] + public class JvmThreadPoolGCTests + { + private readonly ILoggerService _loggerService; + private readonly SparkSession _spark; + private readonly IJvmBridge _jvmBridge; + + public JvmThreadPoolGCTests(SparkFixture fixture) + { + _loggerService = LoggerServiceFactory.GetLogger(typeof(JvmThreadPoolGCTests)); + _spark = fixture.Spark; + _jvmBridge = ((IJvmObjectReferenceProvider)_spark).Reference.Jvm; + } + + /// + /// Test that the active SparkSession is thread-specific. + /// + [Fact] + public void TestThreadLocalSessions() + { + SparkSession.ClearActiveSession(); + + void testChildThread(string appName) + { + var thread = new Thread(() => + { + Assert.Null(SparkSession.GetActiveSession()); + + SparkSession.SetActiveSession( + SparkSession.Builder().AppName(appName).GetOrCreate()); + + // Since we are in the child thread, GetActiveSession() should return the child + // SparkSession. + SparkSession activeSession = SparkSession.GetActiveSession(); + Assert.NotNull(activeSession); + Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null)); + }); + + thread.Start(); + thread.Join(); + } + + for (int i = 0; i < 5; ++i) + { + testChildThread(i.ToString()); + } + + Assert.Null(SparkSession.GetActiveSession()); + } + + /// + /// Monitor a thread via the JvmThreadPoolGC. + /// + [Fact] + public void TestTryAddThread() + { + using var threadPool = new JvmThreadPoolGC( + _loggerService, _jvmBridge, TimeSpan.FromMinutes(30)); + + var thread = new Thread(() => _spark.Sql("SELECT TRUE")); + thread.Start(); + + Assert.True(threadPool.TryAddThread(thread)); + // Subsequent call should return false, because the thread has already been added. + Assert.False(threadPool.TryAddThread(thread)); + + thread.Join(); + } + + /// + /// Create a Spark worker thread in the JVM ThreadPool then remove it directly through + /// the JvmBridge. + /// + [Fact] + public void TestRmThread() + { + // Create a thread and ensure that it is initialized in the JVM ThreadPool. + var thread = new Thread(() => _spark.Sql("SELECT TRUE")); + thread.Start(); + thread.Join(); + + // First call should return true. Second call should return false. + Assert.True((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId)); + Assert.False((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId)); + } + + /// + /// Test that the GC interval configuration defaults to 5 minutes, and can be updated + /// correctly by setting the environment variable. + /// + [Fact] + public void TestIntervalConfiguration() + { + // Default value is 5 minutes. 
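+            // DOTNET_JVM_THREAD_GC_INTERVAL is expected to be unset in the test environment,
+            // so the configuration service falls back to the default interval.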
+ Assert.Null(Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL")); + Assert.Equal( + TimeSpan.FromMinutes(5), + SparkEnvironment.ConfigurationService.JvmThreadGCInterval); + + // Test a custom value. + Environment.SetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL", "1:30:00"); + Assert.Equal( + TimeSpan.FromMinutes(90), + SparkEnvironment.ConfigurationService.JvmThreadGCInterval); + } + } +} diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs index 6cf56617d..cc1542a42 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs @@ -35,6 +35,10 @@ public void TestSignaturesV2_3_X() Assert.IsType(SparkSession.Builder()); + SparkSession.ClearActiveSession(); + SparkSession.SetActiveSession(_spark); + Assert.IsType(SparkSession.GetActiveSession()); + SparkSession.ClearDefaultSession(); SparkSession.SetDefaultSession(_spark); Assert.IsType(SparkSession.GetDefaultSession()); @@ -76,7 +80,7 @@ public void TestSignaturesV2_4_X() /// [Fact] public void TestCreateDataFrame() - { + { // Calling CreateDataFrame with schema { var data = new List diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index 0fb489a91..1dc53ef13 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -8,6 +8,7 @@ using System.IO; using System.Net; using System.Text; +using System.Threading; using Microsoft.Spark.Network; using Microsoft.Spark.Services; @@ -35,6 +36,7 @@ internal sealed class JvmBridge : IJvmBridge private readonly ILoggerService _logger = LoggerServiceFactory.GetLogger(typeof(JvmBridge)); private readonly int _portNumber; + private readonly JvmThreadPoolGC _jvmThreadPoolGC; internal JvmBridge(int portNumber) { @@ -45,6 +47,9 @@ internal JvmBridge(int portNumber) _portNumber = portNumber; _logger.LogInfo($"JvMBridge port is {portNumber}"); + + _jvmThreadPoolGC = new JvmThreadPoolGC( + _logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval); } private ISocketWrapper GetConnection() @@ -158,11 +163,13 @@ private object CallJavaMethod( ISocketWrapper socket = null; try { + Thread thread = Thread.CurrentThread; MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream(); payloadMemoryStream.Position = 0; PayloadHelper.BuildPayload( payloadMemoryStream, isStatic, + thread.ManagedThreadId, classNameOrJvmObjectReference, methodName, args); @@ -176,6 +183,8 @@ private object CallJavaMethod( (int)payloadMemoryStream.Position); outputStream.Flush(); + _jvmThreadPoolGC.TryAddThread(thread); + Stream inputStream = socket.InputStream; int isMethodCallFailed = SerDe.ReadInt32(inputStream); if (isMethodCallFailed != 0) @@ -410,6 +419,7 @@ private object ReadCollection(Stream s) public void Dispose() { + _jvmThreadPoolGC.Dispose(); while (_sockets.TryDequeue(out ISocketWrapper socket)) { if (socket != null) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs new file mode 100644 index 000000000..0eacebadd --- /dev/null +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs @@ -0,0 +1,149 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using Microsoft.Spark.Services; + +namespace Microsoft.Spark.Interop.Ipc +{ + /// + /// In .NET for Apache Spark, we maintain a 1-to-1 mapping between .NET application threads + /// and corresponding JVM threads. When a .NET thread calls a Spark API, that call is executed + /// by its corresponding JVM thread. This functionality allows for multithreaded applications + /// with thread-local variables. + /// + /// This class keeps track of the .NET application thread lifecycle. When a .NET application + /// thread is no longer alive, this class submits an rmThread command to the JVM backend to + /// dispose of its corresponding JVM thread. All methods are thread-safe. + /// + internal class JvmThreadPoolGC : IDisposable + { + private readonly ILoggerService _loggerService; + private readonly IJvmBridge _jvmBridge; + private readonly TimeSpan _threadGCInterval; + private readonly ConcurrentDictionary _activeThreads; + + private readonly object _activeThreadGCTimerLock; + private Timer _activeThreadGCTimer; + + /// + /// Construct the JvmThreadPoolGC. + /// + /// Logger service. + /// The JvmBridge used to call JVM methods. + /// The interval to GC finished threads. + public JvmThreadPoolGC(ILoggerService loggerService, IJvmBridge jvmBridge, TimeSpan threadGCInterval) + { + _loggerService = loggerService; + _jvmBridge = jvmBridge; + _threadGCInterval = threadGCInterval; + _activeThreads = new ConcurrentDictionary(); + + _activeThreadGCTimerLock = new object(); + _activeThreadGCTimer = null; + } + + /// + /// Dispose of the GC timer and run a final round of thread GC. + /// + public void Dispose() + { + lock (_activeThreadGCTimerLock) + { + if (_activeThreadGCTimer != null) + { + _activeThreadGCTimer.Dispose(); + _activeThreadGCTimer = null; + } + } + + GCThreads(); + } + + /// + /// Try to start monitoring a thread. + /// + /// The thread to add. + /// True if success, false if already added. + public bool TryAddThread(Thread thread) + { + bool returnValue = _activeThreads.TryAdd(thread.ManagedThreadId, thread); + + // Initialize the GC timer if necessary. + if (_activeThreadGCTimer == null) + { + lock (_activeThreadGCTimerLock) + { + if (_activeThreadGCTimer == null && _activeThreads.Count > 0) + { + _activeThreadGCTimer = new Timer( + (state) => GCThreads(), + null, + _threadGCInterval, + _threadGCInterval); + } + } + } + + return returnValue; + } + + /// + /// Try to remove a thread from the pool. If the removal is successful, then the + /// corresponding JVM thread will also be disposed. + /// + /// The ID of the thread to remove. + /// True if success, false if the thread cannot be found. + private bool TryDisposeJvmThread(int threadId) + { + if (_activeThreads.TryRemove(threadId, out _)) + { + // _activeThreads does not have ownership of the threads on the .NET side. This + // class does not need to call Join() on the .NET Thread. However, this class is + // responsible for sending the rmThread command to the JVM to trigger disposal + // of the corresponding JVM thread. + if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId)) + { + _loggerService.LogDebug($"GC'd JVM thread {threadId}."); + return true; + } + else + { + _loggerService.LogWarn( + $"rmThread returned false for JVM thread {threadId}. 
" + + $"Either thread does not exist or has already been GC'd."); + } + } + + return false; + } + + /// + /// Remove any threads that are no longer active. + /// + private void GCThreads() + { + foreach (KeyValuePair kvp in _activeThreads) + { + if (!kvp.Value.IsAlive) + { + TryDisposeJvmThread(kvp.Key); + } + } + + lock (_activeThreadGCTimerLock) + { + // Dispose of the timer if there are no threads to monitor. + if (_activeThreadGCTimer != null && _activeThreads.IsEmpty) + { + _activeThreadGCTimer.Dispose(); + _activeThreadGCTimer = null; + } + } + } + } +} diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs index 3d7315b2a..569744713 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs @@ -27,7 +27,7 @@ internal class PayloadHelper private static readonly byte[] s_timestampTypeId = new[] { (byte)'t' }; private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' }; private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' }; - private static readonly byte[] s_doubleArrayArrayTypeId = new[] { ( byte)'A' }; + private static readonly byte[] s_doubleArrayArrayTypeId = new[] { (byte)'A' }; private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' }; private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' }; private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' }; @@ -39,6 +39,7 @@ internal class PayloadHelper internal static void BuildPayload( MemoryStream destination, bool isStaticMethod, + int threadId, object classNameOrJvmObjectReference, string methodName, object[] args) @@ -48,6 +49,7 @@ internal static void BuildPayload( destination.Position += sizeof(int); SerDe.Write(destination, isStaticMethod); + SerDe.Write(destination, threadId); SerDe.Write(destination, classNameOrJvmObjectReference.ToString()); SerDe.Write(destination, methodName); SerDe.Write(destination, args.Length); @@ -140,7 +142,7 @@ internal static void ConvertArgsToBytes( SerDe.Write(destination, d); } break; - + case double[][] argDoubleArrayArray: SerDe.Write(destination, s_doubleArrayArrayTypeId); SerDe.Write(destination, argDoubleArrayArray.Length); diff --git a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs index 3b7de1555..4ce565c84 100644 --- a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs +++ b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs @@ -33,6 +33,18 @@ internal sealed class ConfigurationService : IConfigurationService private string _workerPath; + /// + /// How often to run GC on JVM ThreadPool threads. Defaults to 5 minutes. + /// + public TimeSpan JvmThreadGCInterval + { + get + { + string envVar = Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL"); + return string.IsNullOrEmpty(envVar) ? TimeSpan.FromMinutes(5) : TimeSpan.Parse(envVar); + } + } + /// /// Returns the port number for socket communication between JVM and CLR. /// diff --git a/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs index 5c7a4074f..5398632bd 100644 --- a/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs +++ b/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System; + namespace Microsoft.Spark.Services { /// @@ -9,6 +11,11 @@ namespace Microsoft.Spark.Services /// internal interface IConfigurationService { + /// + /// How often to run GC on JVM ThreadPool threads. + /// + TimeSpan JvmThreadGCInterval { get; } + /// /// The port number used for communicating with the .NET backend process. /// diff --git a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs index cc2e70e26..c90bc3ce9 100644 --- a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs +++ b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs @@ -61,10 +61,40 @@ internal SparkSession(JvmObjectReference jvmObject) /// Builder object public static Builder Builder() => new Builder(); - /// Note that *ActiveSession() APIs are not exposed because these APIs work with a - /// thread-local variable, which stores the session variable. Since the Netty server - /// that handles the requests is multi-threaded, any thread can invoke these APIs, - /// resulting in unexpected behaviors if different threads are used. + /// + /// Changes the SparkSession that will be returned in this thread when + /// is called. This can be used to ensure that a given + /// thread receives a SparkSession with an isolated session, instead of the global + /// (first created) context. + /// + /// SparkSession object + public static void SetActiveSession(SparkSession session) => + session._jvmObject.Jvm.CallStaticJavaMethod( + s_sparkSessionClassName, "setActiveSession", session); + + /// + /// Clears the active SparkSession for current thread. Subsequent calls to + /// will return the first created context + /// instead of a thread-local override. + /// + public static void ClearActiveSession() => + SparkEnvironment.JvmBridge.CallStaticJavaMethod( + s_sparkSessionClassName, "clearActiveSession"); + + /// + /// Returns the active SparkSession for the current thread, returned by the builder. + /// + /// Return null, when calling this function on executors + public static SparkSession GetActiveSession() + { + var optionalSession = new Option( + (JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod( + s_sparkSessionClassName, "getActiveSession")); + + return optionalSession.IsDefined() + ? new SparkSession((JvmObjectReference)optionalSession.Get()) + : null; + } /// /// Sets the default SparkSession that is returned by the builder. diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 1cde1d1c5..e632589e4 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.collection.mutable.HashMap +import scala.language.existentials + import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -import scala.collection.mutable.HashMap -import scala.language.existentials - /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. 
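As a quick illustration of the *ActiveSession APIs introduced in SparkSession.cs above, here is a minimal sketch (not part of this diff) of per-thread session isolation; the app name and query are placeholders, and the flow mirrors TestThreadLocalSessions.

```csharp
using System.Threading;
using Microsoft.Spark.Sql;

class ActiveSessionSketch
{
    static void Main()
    {
        // Start from a clean slate on the main thread.
        SparkSession.ClearActiveSession();

        var worker = new Thread(() =>
        {
            // The worker thread has no active session yet; give it its own.
            SparkSession.SetActiveSession(
                SparkSession.Builder().AppName("worker-app").GetOrCreate());

            // Calls on this thread now observe the thread-local session.
            SparkSession active = SparkSession.GetActiveSession();
            active.Sql("SELECT TRUE").Show();
        });

        worker.Start();
        worker.Join();

        // The main thread never set an active session of its own, so
        // SparkSession.GetActiveSession() still returns null here.
    }
}
```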
@@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend) // First bit is isStatic val isStatic = readBoolean(dis) + val threadId = readInt(dis) val objId = readString(dis) val methodName = readString(dis) val numArgs = readInt(dis) @@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend) logError(s"Removing $objId failed", e) writeInt(dos, -1) } + case "rmThread" => + try { + assert(readObjectType(dis) == 'i') + val threadToDelete = readInt(dis) + val result = ThreadPool.tryDeleteThread(threadToDelete) + writeInt(dos, 0) + writeObject(dos, result.asInstanceOf[AnyRef]) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + writeInt(dos, -1) + } case "connectCallback" => assert(readObjectType(dis) == 'c') val address = readString(dis) assert(readObjectType(dis) == 'i') val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port); + DotnetBackend.setCallbackClient(address, port) writeInt(dos, 0) writeType(dos, "void") case "closeCallback" => @@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend) case _ => dos.writeInt(-1) } } else { - handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + ThreadPool + .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) } bos.toByteArray @@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend) "invalid method " + methodName + " for object " + objId) } } catch { - case e: Exception => + case e: Throwable => val jvmObj = JVMObjectTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName case None => "NullObject" } - val argsStr = args.map(arg => { - if (arg != null) { - s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" - } else { - "[Value: NULL]" - } - }).mkString(", ") + val argsStr = args + .map(arg => { + if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" + } else { + "[Value: NULL]" + } + }) + .mkString(", ") logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala new file mode 100644 index 000000000..1888ec746 --- /dev/null +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.util.concurrent.{ExecutorService, Executors} + +import scala.collection.mutable + +/** + * Pool of thread executors. There should be a 1-1 correspondence between C# threads + * and Java threads. + */ +object ThreadPool { + + /** + * Map from threadId to corresponding executor. + */ + private val executors: mutable.HashMap[Int, ExecutorService] = + new mutable.HashMap[Int, ExecutorService]() + + /** + * Run some code on a particular thread. + * + * @param threadId Integer id of the thread. + * @param task Function to run on the thread. + */ + def run(threadId: Int, task: () => Unit): Unit = { + val executor = getOrCreateExecutor(threadId) + val future = executor.submit(new Runnable { + override def run(): Unit = task() + }) + + future.get() + } + + /** + * Try to delete a particular thread. 
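+   * The removed executor is shut down: tasks already submitted may finish, but it
+   * accepts no new work.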
+ * + * @param threadId Integer id of the thread. + * @return True if successful, false if thread does not exist. + */ + def tryDeleteThread(threadId: Int): Boolean = synchronized { + executors.remove(threadId) match { + case Some(executorService) => + executorService.shutdown() + true + case None => false + } + } + + /** + * Get the executor if it exists, otherwise create a new one. + * + * @param id Integer id of the thread. + * @return The new or existing executor with the given id. + */ + private def getOrCreateExecutor(id: Int): ExecutorService = synchronized { + executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor) + } +} diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 1cde1d1c5..e632589e4 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.collection.mutable.HashMap +import scala.language.existentials + import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -import scala.collection.mutable.HashMap -import scala.language.existentials - /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. @@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend) // First bit is isStatic val isStatic = readBoolean(dis) + val threadId = readInt(dis) val objId = readString(dis) val methodName = readString(dis) val numArgs = readInt(dis) @@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend) logError(s"Removing $objId failed", e) writeInt(dos, -1) } + case "rmThread" => + try { + assert(readObjectType(dis) == 'i') + val threadToDelete = readInt(dis) + val result = ThreadPool.tryDeleteThread(threadToDelete) + writeInt(dos, 0) + writeObject(dos, result.asInstanceOf[AnyRef]) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + writeInt(dos, -1) + } case "connectCallback" => assert(readObjectType(dis) == 'c') val address = readString(dis) assert(readObjectType(dis) == 'i') val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port); + DotnetBackend.setCallbackClient(address, port) writeInt(dos, 0) writeType(dos, "void") case "closeCallback" => @@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend) case _ => dos.writeInt(-1) } } else { - handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + ThreadPool + .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) } bos.toByteArray @@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend) "invalid method " + methodName + " for object " + objId) } } catch { - case e: Exception => + case e: Throwable => val jvmObj = JVMObjectTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName case None => "NullObject" } - val argsStr = args.map(arg => { - if (arg != null) { - s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" - } else { - "[Value: NULL]" - } - }).mkString(", ") + val argsStr = args + .map(arg => { + 
if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" + } else { + "[Value: NULL]" + } + }) + .mkString(", ") logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala new file mode 100644 index 000000000..1888ec746 --- /dev/null +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.util.concurrent.{ExecutorService, Executors} + +import scala.collection.mutable + +/** + * Pool of thread executors. There should be a 1-1 correspondence between C# threads + * and Java threads. + */ +object ThreadPool { + + /** + * Map from threadId to corresponding executor. + */ + private val executors: mutable.HashMap[Int, ExecutorService] = + new mutable.HashMap[Int, ExecutorService]() + + /** + * Run some code on a particular thread. + * + * @param threadId Integer id of the thread. + * @param task Function to run on the thread. + */ + def run(threadId: Int, task: () => Unit): Unit = { + val executor = getOrCreateExecutor(threadId) + val future = executor.submit(new Runnable { + override def run(): Unit = task() + }) + + future.get() + } + + /** + * Try to delete a particular thread. + * + * @param threadId Integer id of the thread. + * @return True if successful, false if thread does not exist. + */ + def tryDeleteThread(threadId: Int): Boolean = synchronized { + executors.remove(threadId) match { + case Some(executorService) => + executorService.shutdown() + true + case None => false + } + } + + /** + * Get the executor if it exists, otherwise create a new one. + * + * @param id Integer id of the thread. + * @return The new or existing executor with the given id. 
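+   * The executor is single-threaded, so all calls issued by the same C# thread run
+   * sequentially on the same JVM thread.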
+ */ + private def getOrCreateExecutor(id: Int): ExecutorService = synchronized { + executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor) + } +} diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 1cde1d1c5..1446e5ff6 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend) // First bit is isStatic val isStatic = readBoolean(dis) + val threadId = readInt(dis) val objId = readString(dis) val methodName = readString(dis) val numArgs = readInt(dis) @@ -65,12 +66,24 @@ class DotnetBackendHandler(server: DotnetBackend) logError(s"Removing $objId failed", e) writeInt(dos, -1) } + case "rmThread" => + try { + assert(readObjectType(dis) == 'i') + val threadToDelete = readInt(dis) + val result = ThreadPool.tryDeleteThread(threadToDelete) + writeInt(dos, 0) + writeObject(dos, result.asInstanceOf[AnyRef]) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + writeInt(dos, -1) + } case "connectCallback" => assert(readObjectType(dis) == 'c') val address = readString(dis) assert(readObjectType(dis) == 'i') val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port); + DotnetBackend.setCallbackClient(address, port) writeInt(dos, 0) writeType(dos, "void") case "closeCallback" => @@ -82,7 +95,8 @@ class DotnetBackendHandler(server: DotnetBackend) case _ => dos.writeInt(-1) } } else { - handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + ThreadPool + .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) } bos.toByteArray @@ -162,19 +176,21 @@ class DotnetBackendHandler(server: DotnetBackend) "invalid method " + methodName + " for object " + objId) } } catch { - case e: Exception => + case e: Throwable => val jvmObj = JVMObjectTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName case None => "NullObject" } - val argsStr = args.map(arg => { - if (arg != null) { - s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" - } else { - "[Value: NULL]" - } - }).mkString(", ") + val argsStr = args + .map(arg => { + if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" + } else { + "[Value: NULL]" + } + }) + .mkString(", ") logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala new file mode 100644 index 000000000..1888ec746 --- /dev/null +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.util.concurrent.{ExecutorService, Executors} + +import scala.collection.mutable + +/** + * Pool of thread executors. There should be a 1-1 correspondence between C# threads + * and Java threads. 
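+ * Executors are keyed by the .NET thread id that is sent with every method-call payload.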
+ */ +object ThreadPool { + + /** + * Map from threadId to corresponding executor. + */ + private val executors: mutable.HashMap[Int, ExecutorService] = + new mutable.HashMap[Int, ExecutorService]() + + /** + * Run some code on a particular thread. + * + * @param threadId Integer id of the thread. + * @param task Function to run on the thread. + */ + def run(threadId: Int, task: () => Unit): Unit = { + val executor = getOrCreateExecutor(threadId) + val future = executor.submit(new Runnable { + override def run(): Unit = task() + }) + + future.get() + } + + /** + * Try to delete a particular thread. + * + * @param threadId Integer id of the thread. + * @return True if successful, false if thread does not exist. + */ + def tryDeleteThread(threadId: Int): Boolean = synchronized { + executors.remove(threadId) match { + case Some(executorService) => + executorService.shutdown() + true + case None => false + } + } + + /** + * Get the executor if it exists, otherwise create a new one. + * + * @param id Integer id of the thread. + * @return The new or existing executor with the given id. + */ + private def getOrCreateExecutor(id: Int): ExecutorService = synchronized { + executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor) + } +}
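To round out the picture, here is a minimal end-to-end sketch (not part of this diff) of the cleanup flow driven by JvmThreadPoolGC and the new ThreadPool. The app name, query, and ten-minute interval are illustrative; the environment variable is set before the first Spark call so the JvmBridge picks it up when it constructs its JvmThreadPoolGC.

```csharp
using System;
using System.Threading;
using Microsoft.Spark.Sql;

class ThreadGCIntervalSketch
{
    static void Main()
    {
        // Override the 5-minute default; the value is parsed with TimeSpan.Parse,
        // so "0:10:00" means ten minutes. Set it before the session is created.
        Environment.SetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL", "0:10:00");

        SparkSession spark = SparkSession.Builder().AppName("gc-sketch").GetOrCreate();

        // The worker's first JVM call registers it with JvmThreadPoolGC and creates
        // a dedicated single-threaded executor on the JVM side.
        var worker = new Thread(() => spark.Sql("SELECT TRUE").Show());
        worker.Start();
        worker.Join();

        // Once the worker is no longer alive, a later timer tick sends "rmThread"
        // for its managed thread id and the matching JVM executor is shut down.
    }
}
```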