From fdc1b20e26bd7c22a3b9f31ca72b7c9e07b04b5c Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 24 Aug 2020 21:13:25 -0700 Subject: [PATCH 01/42] Embed thread ID in payload --- src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs index 8580efcae..cf61f33f2 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Threading; using Microsoft.Spark.Sql; using Microsoft.Spark.Sql.Types; @@ -27,7 +28,7 @@ internal class PayloadHelper private static readonly byte[] s_timestampTypeId = new[] { (byte)'t' }; private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' }; private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' }; - private static readonly byte[] s_doubleArrayArrayTypeId = new[] { ( byte)'A' }; + private static readonly byte[] s_doubleArrayArrayTypeId = new[] { (byte)'A' }; private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' }; private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' }; private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' }; @@ -47,6 +48,7 @@ internal static void BuildPayload( destination.Position += sizeof(int); SerDe.Write(destination, isStaticMethod); + SerDe.Write(destination, Thread.CurrentThread.ManagedThreadId); SerDe.Write(destination, classNameOrJvmObjectReference.ToString()); SerDe.Write(destination, methodName); SerDe.Write(destination, args.Length); @@ -139,7 +141,7 @@ internal static void ConvertArgsToBytes( SerDe.Write(destination, d); } break; - + case double[][] argDoubleArrayArray: SerDe.Write(destination, s_doubleArrayArrayTypeId); SerDe.Write(destination, argDoubleArrayArray.Length); From 88befdbfb574db90ad39dbfd3a064a29eaac9e9d Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 24 Aug 2020 21:21:11 -0700 Subject: [PATCH 02/42] Basic execution threadpool --- .../api/dotnet/DotnetBackendHandler.scala | 39 +++++++++----- .../apache/spark/api/dotnet/ThreadPool.scala | 52 +++++++++++++++++++ 2 files changed, 79 insertions(+), 12 deletions(-) create mode 100644 src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 1cde1d1c5..c89249603 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.collection.mutable.HashMap +import scala.language.existentials + import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -import scala.collection.mutable.HashMap -import scala.language.existentials - /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. @@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend) // First bit is isStatic val isStatic = readBoolean(dis) + val threadId = readInt(dis) val objId = readString(dis) val methodName = readString(dis) val numArgs = readInt(dis) @@ -57,7 +58,7 @@ class DotnetBackendHandler(server: DotnetBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) + ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)).wait() writeInt(dos, 0) writeObject(dos, null) } catch { @@ -65,6 +66,16 @@ class DotnetBackendHandler(server: DotnetBackend) logError(s"Removing $objId failed", e) writeInt(dos, -1) } + case "rmThread" => + try { + ThreadPool.deleteThread(threadId) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + writeInt(dos, -1) + } case "connectCallback" => assert(readObjectType(dis) == 'c') val address = readString(dis) @@ -82,7 +93,9 @@ class DotnetBackendHandler(server: DotnetBackend) case _ => dos.writeInt(-1) } } else { - handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + ThreadPool + .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) + .wait() } bos.toByteArray @@ -168,13 +181,15 @@ class DotnetBackendHandler(server: DotnetBackend) case Some(jObj) => jObj.getClass.getName case None => "NullObject" } - val argsStr = args.map(arg => { - if (arg != null) { - s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" - } else { - "[Value: NULL]" - } - }).mkString(", ") + val argsStr = args + .map(arg => { + if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" + } else { + "[Value: NULL]" + } + }) + .mkString(", ") logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala new file mode 100644 index 000000000..148df5bc3 --- /dev/null +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -0,0 +1,52 @@ +package org.apache.spark.api.dotnet + +import java.util.concurrent.{ExecutorService, Executors, Future} + +import scala.collection.mutable + +/** + * Pool of thread executors. There should be a 1-1 correspondence between C# threads + * and Java threads. + */ +object ThreadPool { + + /** + * Map from threadId to corresponding executor. + */ + val executors: mutable.Map[Int, ExecutorService] = mutable.Map() + + /** + * Run some code on a particular thread. + * + * @param threadId + * @param task + * @return + */ + def run(threadId: Int, task: () => Unit): Future[_] = + getOrCreateExecutor(threadId).submit(new Runnable { + override def run(): Unit = task + }) + + /** + * Delete a particular thread. + * + * @param threadId + */ + def deleteThread(threadId: Int) = { + getOrCreateExecutor(threadId).shutdown() + executors.remove(threadId) + } + + /** + * Get the executor if it exists, otherwise create a new one. + * + * @param id + * @return + */ + private def getOrCreateExecutor(id: Int): ExecutorService = + executors.getOrElse(id, { + val thread = Executors.newSingleThreadExecutor() + executors.put(id, thread) + thread + }) +} From d9a403c15e694feec423cc7034343421840551cc Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 24 Aug 2020 21:51:25 -0700 Subject: [PATCH 03/42] Thread lifecycle management --- .../Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 21 +++++++++++++++++++ .../Interop/Ipc/PayloadHelper.cs | 3 ++- .../api/dotnet/DotnetBackendHandler.scala | 4 +++- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index 231263c74..e65354281 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -8,6 +8,7 @@ using System.IO; using System.Net; using System.Text; +using System.Threading; using Microsoft.Spark.Network; using Microsoft.Spark.Services; @@ -35,6 +36,8 @@ internal sealed class JvmBridge : IJvmBridge private readonly ILoggerService _logger = LoggerServiceFactory.GetLogger(typeof(JvmBridge)); private readonly int _portNumber; + private readonly HashSet _activeThreads; + private readonly Timer _activeThreadMonitor; internal JvmBridge(int portNumber) { @@ -45,6 +48,18 @@ internal JvmBridge(int portNumber) _portNumber = portNumber; _logger.LogInfo($"JvMBridge port is {portNumber}"); + + _activeThreads = new HashSet(); + _activeThreadMonitor = new Timer((state) => + { + foreach (Thread thread in _activeThreads) + { + if (!thread.IsAlive) + { + CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); + } + } + }, null, 0, 10000); } private ISocketWrapper GetConnection() @@ -158,12 +173,14 @@ private object CallJavaMethod( ISocketWrapper socket = null; try { + Thread thread = Thread.CurrentThread; MemoryStream payloadMemoryStream = s_payloadMemoryStream ?? (s_payloadMemoryStream = new MemoryStream()); payloadMemoryStream.Position = 0; PayloadHelper.BuildPayload( payloadMemoryStream, isStatic, + thread.ManagedThreadId, classNameOrJvmObjectReference, methodName, args); @@ -177,6 +194,8 @@ private object CallJavaMethod( (int)payloadMemoryStream.Position); outputStream.Flush(); + _activeThreads.Add(thread); + Stream inputStream = socket.InputStream; int isMethodCallFailed = SerDe.ReadInt32(inputStream); if (isMethodCallFailed != 0) @@ -418,6 +437,8 @@ public void Dispose() socket.Dispose(); } } + + _activeThreadMonitor.Dispose(); } } } diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs index cf61f33f2..31b63f521 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs @@ -39,6 +39,7 @@ internal class PayloadHelper internal static void BuildPayload( MemoryStream destination, bool isStaticMethod, + int threadId, object classNameOrJvmObjectReference, string methodName, object[] args) @@ -48,7 +49,7 @@ internal static void BuildPayload( destination.Position += sizeof(int); SerDe.Write(destination, isStaticMethod); - SerDe.Write(destination, Thread.CurrentThread.ManagedThreadId); + SerDe.Write(destination, threadId); SerDe.Write(destination, classNameOrJvmObjectReference.ToString()); SerDe.Write(destination, methodName); SerDe.Write(destination, args.Length); diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index c89249603..17f7f2384 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -68,7 +68,9 @@ class DotnetBackendHandler(server: DotnetBackend) } case "rmThread" => try { - ThreadPool.deleteThread(threadId) + assert(readObjectType(dis) == 'i') + val threadToDelete = readInt(dis) + ThreadPool.deleteThread(threadToDelete) writeInt(dos, 0) writeObject(dos, null) } catch { From 33c0202954214b2dd80674b600e29b8fe5623f0e Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Tue, 25 Aug 2020 11:53:31 -0700 Subject: [PATCH 04/42] Port to spark 2.3 --- .../api/dotnet/DotnetBackendHandler.scala | 41 ++++++++++----- .../apache/spark/api/dotnet/ThreadPool.scala | 52 +++++++++++++++++++ 2 files changed, 81 insertions(+), 12 deletions(-) create mode 100644 src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 1cde1d1c5..17f7f2384 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.collection.mutable.HashMap +import scala.language.existentials + import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -import scala.collection.mutable.HashMap -import scala.language.existentials - /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. @@ -42,6 +42,7 @@ class DotnetBackendHandler(server: DotnetBackend) // First bit is isStatic val isStatic = readBoolean(dis) + val threadId = readInt(dis) val objId = readString(dis) val methodName = readString(dis) val numArgs = readInt(dis) @@ -57,7 +58,7 @@ class DotnetBackendHandler(server: DotnetBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) + ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)).wait() writeInt(dos, 0) writeObject(dos, null) } catch { @@ -65,6 +66,18 @@ class DotnetBackendHandler(server: DotnetBackend) logError(s"Removing $objId failed", e) writeInt(dos, -1) } + case "rmThread" => + try { + assert(readObjectType(dis) == 'i') + val threadToDelete = readInt(dis) + ThreadPool.deleteThread(threadToDelete) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + writeInt(dos, -1) + } case "connectCallback" => assert(readObjectType(dis) == 'c') val address = readString(dis) @@ -82,7 +95,9 @@ class DotnetBackendHandler(server: DotnetBackend) case _ => dos.writeInt(-1) } } else { - handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + ThreadPool + .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) + .wait() } bos.toByteArray @@ -168,13 +183,15 @@ class DotnetBackendHandler(server: DotnetBackend) case Some(jObj) => jObj.getClass.getName case None => "NullObject" } - val argsStr = args.map(arg => { - if (arg != null) { - s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" - } else { - "[Value: NULL]" - } - }).mkString(", ") + val argsStr = args + .map(arg => { + if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" + } else { + "[Value: NULL]" + } + }) + .mkString(", ") logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala new file mode 100644 index 000000000..148df5bc3 --- /dev/null +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -0,0 +1,52 @@ +package org.apache.spark.api.dotnet + +import java.util.concurrent.{ExecutorService, Executors, Future} + +import scala.collection.mutable + +/** + * Pool of thread executors. There should be a 1-1 correspondence between C# threads + * and Java threads. + */ +object ThreadPool { + + /** + * Map from threadId to corresponding executor. + */ + val executors: mutable.Map[Int, ExecutorService] = mutable.Map() + + /** + * Run some code on a particular thread. + * + * @param threadId + * @param task + * @return + */ + def run(threadId: Int, task: () => Unit): Future[_] = + getOrCreateExecutor(threadId).submit(new Runnable { + override def run(): Unit = task + }) + + /** + * Delete a particular thread. + * + * @param threadId + */ + def deleteThread(threadId: Int) = { + getOrCreateExecutor(threadId).shutdown() + executors.remove(threadId) + } + + /** + * Get the executor if it exists, otherwise create a new one. + * + * @param id + * @return + */ + private def getOrCreateExecutor(id: Int): ExecutorService = + executors.getOrElse(id, { + val thread = Executors.newSingleThreadExecutor() + executors.put(id, thread) + thread + }) +} From a0740289c540ace9c1416460b1bc28188ef26e4d Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Tue, 25 Aug 2020 13:34:43 -0700 Subject: [PATCH 05/42] Concurrent dictionary for active threads --- .../Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index e65354281..f87f6567f 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -36,7 +36,7 @@ internal sealed class JvmBridge : IJvmBridge private readonly ILoggerService _logger = LoggerServiceFactory.GetLogger(typeof(JvmBridge)); private readonly int _portNumber; - private readonly HashSet _activeThreads; + private readonly ConcurrentDictionary _activeThreads; private readonly Timer _activeThreadMonitor; internal JvmBridge(int portNumber) @@ -49,17 +49,24 @@ internal JvmBridge(int portNumber) _portNumber = portNumber; _logger.LogInfo($"JvMBridge port is {portNumber}"); - _activeThreads = new HashSet(); + _activeThreads = new ConcurrentDictionary(); _activeThreadMonitor = new Timer((state) => { - foreach (Thread thread in _activeThreads) + foreach (var threadId in _activeThreads.Keys) { - if (!thread.IsAlive) + if (_activeThreads.TryRemove(threadId, out Thread thread)) { - CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); + if (thread.IsAlive) + { + _activeThreads.TryAdd(threadId, thread); + } + else + { + CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); + } } } - }, null, 0, 10000); + }, null, 0, 30000); } private ISocketWrapper GetConnection() @@ -194,7 +201,7 @@ private object CallJavaMethod( (int)payloadMemoryStream.Position); outputStream.Flush(); - _activeThreads.Add(thread); + _activeThreads.TryAdd(thread.ManagedThreadId, thread); Stream inputStream = socket.InputStream; int isMethodCallFailed = SerDe.ReadInt32(inputStream); From 6756f13f7f73f0047dcc8c02706fdabf55687856 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 7 Sep 2020 19:49:04 -0700 Subject: [PATCH 06/42] Logic to clean up expired threads --- .../Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 30 ++++++++++--------- .../api/dotnet/DotnetBackendHandler.scala | 3 +- .../apache/spark/api/dotnet/ThreadPool.scala | 21 ++++++++----- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index f87f6567f..b722b45cf 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -37,7 +37,7 @@ internal sealed class JvmBridge : IJvmBridge LoggerServiceFactory.GetLogger(typeof(JvmBridge)); private readonly int _portNumber; private readonly ConcurrentDictionary _activeThreads; - private readonly Timer _activeThreadMonitor; + private readonly Thread _activeThreadMonitor; internal JvmBridge(int portNumber) { @@ -50,23 +50,27 @@ internal JvmBridge(int portNumber) _logger.LogInfo($"JvMBridge port is {portNumber}"); _activeThreads = new ConcurrentDictionary(); - _activeThreadMonitor = new Timer((state) => + _activeThreadMonitor = new Thread(delegate () { - foreach (var threadId in _activeThreads.Keys) + using var timer = new Timer((state) => { - if (_activeThreads.TryRemove(threadId, out Thread thread)) + foreach (var threadId in _activeThreads.Keys) { - if (thread.IsAlive) + if (_activeThreads.TryRemove(threadId, out Thread thread)) { - _activeThreads.TryAdd(threadId, thread); - } - else - { - CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); + if (thread.IsAlive) + { + _activeThreads.TryAdd(threadId, thread); + } + else + { + CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); + } } } - } - }, null, 0, 30000); + }, null, 0, 30000); + }); + _activeThreadMonitor.Start(); } private ISocketWrapper GetConnection() @@ -444,8 +448,6 @@ public void Dispose() socket.Dispose(); } } - - _activeThreadMonitor.Dispose(); } } } diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 17f7f2384..7b5758822 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -58,7 +58,7 @@ class DotnetBackendHandler(server: DotnetBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)).wait() + ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)) writeInt(dos, 0) writeObject(dos, null) } catch { @@ -97,7 +97,6 @@ class DotnetBackendHandler(server: DotnetBackend) } else { ThreadPool .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) - .wait() } bos.toByteArray diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 148df5bc3..87ec0f04b 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -1,8 +1,8 @@ package org.apache.spark.api.dotnet -import java.util.concurrent.{ExecutorService, Executors, Future} +import java.util.concurrent.{ExecutorService, Executors} -import scala.collection.mutable +import scala.collection._ /** * Pool of thread executors. There should be a 1-1 correspondence between C# threads @@ -13,19 +13,24 @@ object ThreadPool { /** * Map from threadId to corresponding executor. */ - val executors: mutable.Map[Int, ExecutorService] = mutable.Map() + val executors: concurrent.TrieMap[Int, ExecutorService] = + new concurrent.TrieMap[Int, ExecutorService]() /** * Run some code on a particular thread. * * @param threadId * @param task - * @return */ - def run(threadId: Int, task: () => Unit): Future[_] = - getOrCreateExecutor(threadId).submit(new Runnable { - override def run(): Unit = task - }) + def run(threadId: Int, task: () => Unit): Unit = { + val runnable = new Runnable { + override def run(): Unit = task() + } + val future = getOrCreateExecutor(threadId).submit(runnable) + while (!future.isDone) { + Thread.sleep(1000) + } + } /** * Delete a particular thread. From c875e63ae86877b57324173ecb81b7e420da6575 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 7 Sep 2020 19:53:26 -0700 Subject: [PATCH 07/42] Copy changes to Spark 2.3 --- .../api/dotnet/DotnetBackendHandler.scala | 595 +++++++++--------- 1 file changed, 297 insertions(+), 298 deletions(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 17f7f2384..2412a7612 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -22,309 +22,308 @@ import org.apache.spark.util.Utils */ class DotnetBackendHandler(server: DotnetBackend) extends SimpleChannelInboundHandler[Array[Byte]] - with Logging { - - override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { - val reply = handleBackendRequest(msg) - ctx.write(reply) - } - - override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { - ctx.flush() - } - - def handleBackendRequest(msg: Array[Byte]): Array[Byte] = { - val bis = new ByteArrayInputStream(msg) - val dis = new DataInputStream(bis) - - val bos = new ByteArrayOutputStream() - val dos = new DataOutputStream(bos) - - // First bit is isStatic - val isStatic = readBoolean(dis) - val threadId = readInt(dis) - val objId = readString(dis) - val methodName = readString(dis) - val numArgs = readInt(dis) - - if (objId == "DotnetHandler") { - methodName match { - case "stopBackend" => - writeInt(dos, 0) - writeType(dos, "void") - server.close() - case "rm" => - try { - val t = readObjectType(dis) - assert(t == 'c') - val objToRemove = readString(dis) - ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)).wait() - writeInt(dos, 0) - writeObject(dos, null) - } catch { - case e: Exception => - logError(s"Removing $objId failed", e) - writeInt(dos, -1) - } - case "rmThread" => - try { - assert(readObjectType(dis) == 'i') - val threadToDelete = readInt(dis) - ThreadPool.deleteThread(threadToDelete) - writeInt(dos, 0) - writeObject(dos, null) - } catch { - case e: Exception => - logError(s"Removing thread $threadId failed", e) - writeInt(dos, -1) - } - case "connectCallback" => - assert(readObjectType(dis) == 'c') - val address = readString(dis) - assert(readObjectType(dis) == 'i') - val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port); - writeInt(dos, 0) - writeType(dos, "void") - case "closeCallback" => - logInfo("Requesting to close callback client") - DotnetBackend.shutdownCallbackClient() - writeInt(dos, 0) - writeType(dos, "void") - - case _ => dos.writeInt(-1) - } - } else { - ThreadPool - .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) - .wait() - } + with Logging { - bos.toByteArray - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - // Skip logging the exception message if the connection was disconnected from - // the .NET side so that .NET side doesn't have to explicitly close the connection via - // "stopBackend." Note that an exception is still thrown if the exit status is non-zero, - // so skipping this kind of exception message does not affect the debugging. - if (!cause.getMessage.contains( - "An existing connection was forcibly closed by the remote host")) { - logError("Exception caught: ", cause) + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val reply = handleBackendRequest(msg) + ctx.write(reply) } - // Close the connection when an exception is raised. - ctx.close() - } - - def handleMethodCall( - isStatic: Boolean, - objId: String, - methodName: String, - numArgs: Int, - dis: DataInputStream, - dos: DataOutputStream): Unit = { - var obj: Object = null - var args: Array[java.lang.Object] = null - var methods: Array[java.lang.reflect.Method] = null - - try { - val cls = if (isStatic) { - Utils.classForName(objId) - } else { - JVMObjectTracker.get(objId) match { - case None => throw new IllegalArgumentException("Object not found " + objId) - case Some(o) => - obj = o - o.getClass - } - } - - args = readArgs(numArgs, dis) - methods = cls.getMethods - - val selectedMethods = methods.filter(m => m.getName == methodName) - if (selectedMethods.length > 0) { - val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args) - - if (index.isEmpty) { - logWarning( - s"cannot find matching method ${cls}.$methodName. " - + s"Candidates are:") - selectedMethods.foreach { method => - logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") - } - throw new Exception(s"No matched method found for $cls.$methodName") - } + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } - val ret = selectedMethods(index.get).invoke(obj, args: _*) - - // Write status bit - writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) - } else if (methodName == "") { - // methodName should be "" for constructor - val ctor = cls.getConstructors.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - }.head - - val obj = ctor.newInstance(args: _*) - - writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) - } else { - throw new IllegalArgumentException( - "invalid method " + methodName + " for object " + objId) - } - } catch { - case e: Exception => - val jvmObj = JVMObjectTracker.get(objId) - val jvmObjName = jvmObj match { - case Some(jObj) => jObj.getClass.getName - case None => "NullObject" - } - val argsStr = args - .map(arg => { - if (arg != null) { - s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" - } else { - "[Value: NULL]" + def handleBackendRequest(msg: Array[Byte]): Array[Byte] = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = readBoolean(dis) + val threadId = readInt(dis) + val objId = readString(dis) + val methodName = readString(dis) + val numArgs = readInt(dis) + + if (objId == "DotnetHandler") { + methodName match { + case "stopBackend" => + writeInt(dos, 0) + writeType(dos, "void") + server.close() + case "rm" => + try { + val t = readObjectType(dis) + assert(t == 'c') + val objToRemove = readString(dis) + ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + writeInt(dos, -1) + } + case "rmThread" => + try { + assert(readObjectType(dis) == 'i') + val threadToDelete = readInt(dis) + ThreadPool.deleteThread(threadToDelete) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + writeInt(dos, -1) + } + case "connectCallback" => + assert(readObjectType(dis) == 'c') + val address = readString(dis) + assert(readObjectType(dis) == 'i') + val port = readInt(dis) + DotnetBackend.setCallbackClient(address, port); + writeInt(dos, 0) + writeType(dos, "void") + case "closeCallback" => + logInfo("Requesting to close callback client") + DotnetBackend.shutdownCallbackClient() + writeInt(dos, 0) + writeType(dos, "void") + + case _ => dos.writeInt(-1) } - }) - .mkString(", ") + } else { + ThreadPool + .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) + } - logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") + bos.toByteArray + } - if (methods != null) { - logDebug(s"All methods for $jvmObjName:") - methods.foreach(m => logDebug(m.toString)) + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Skip logging the exception message if the connection was disconnected from + // the .NET side so that .NET side doesn't have to explicitly close the connection via + // "stopBackend." Note that an exception is still thrown if the exit status is non-zero, + // so skipping this kind of exception message does not affect the debugging. + if (!cause.getMessage.contains( + "An existing connection was forcibly closed by the remote host")) { + logError("Exception caught: ", cause) } - writeInt(dos, -1) - writeString(dos, Utils.exceptionString(e.getCause)) - } - } - - // Read a number of arguments from the data input stream - def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { - (0 until numArgs).map { arg => - readObject(dis) - }.toArray - } - - // Checks if the arguments passed in args matches the parameter types. - // NOTE: Currently we do exact match. We may add type conversions later. - def matchMethod( - numArgs: Int, - args: Array[java.lang.Object], - parameterTypes: Array[Class[_]]): Boolean = { - if (parameterTypes.length != numArgs) { - return false + // Close the connection when an exception is raised. + ctx.close() } - for (i <- 0 until numArgs) { - val parameterType = parameterTypes(i) - var parameterWrapperType = parameterType - - // Convert native parameters to Object types as args is Array[Object] here - if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Long] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType - } - } + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + var args: Array[java.lang.Object] = null + var methods: Array[java.lang.reflect.Method] = null + + try { + val cls = if (isStatic) { + Utils.classForName(objId) + } else { + JVMObjectTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } - if (!parameterWrapperType.isInstance(args(i))) { - // non primitive types - if (!parameterType.isPrimitive && args(i) != null) { - return false + args = readArgs(numArgs, dis) + methods = cls.getMethods + + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args) + + if (index.isEmpty) { + logWarning( + s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + + val ret = selectedMethods(index.get).invoke(obj, args: _*) + + // Write status bit + writeInt(dos, 0) + writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args: _*) + + writeInt(dos, 0) + writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException( + "invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Exception => + val jvmObj = JVMObjectTracker.get(objId) + val jvmObjName = jvmObj match { + case Some(jObj) => jObj.getClass.getName + case None => "NullObject" + } + val argsStr = args + .map(arg => { + if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" + } else { + "[Value: NULL]" + } + }) + .mkString(", ") + + logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") + + if (methods != null) { + logDebug(s"All methods for $jvmObjName:") + methods.foreach(m => logDebug(m.toString)) + } + + writeInt(dos, -1) + writeString(dos, Utils.exceptionString(e.getCause)) } + } - // primitive types - if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) { - return false - } - } + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + readObject(dis) + }.toArray } - true - } - - // Find a matching method signature in an array of signatures of constructors - // or methods of the same name according to the passed arguments. Arguments - // may be converted in order to match a signature. - // - // Note that in Java reflection, constructors and normal methods are of different - // classes, and share no parent class that provides methods for reflection uses. - // There is no unified way to handle them in this function. So an array of signatures - // is passed in instead of an array of candidate constructors or methods. - // - // Returns an Option[Int] which is the index of the matched signature in the array. - def findMatchedSignature( - parameterTypesOfMethods: Array[Array[Class[_]]], - args: Array[Object]): Option[Int] = { - val numArgs = args.length - - for (index <- parameterTypesOfMethods.indices) { - val parameterTypes = parameterTypesOfMethods(index) - - if (parameterTypes.length == numArgs) { - var argMatched = true - var i = 0 - while (i < numArgs && argMatched) { - val parameterType = parameterTypes(i) - - if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { - // The case that the parameter type is a Scala Seq and the argument - // is a Java array is considered matching. The array will be converted - // to a Seq later if this method is matched. - } else { + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. + def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 until numArgs) { + val parameterType = parameterTypes(i) var parameterWrapperType = parameterType // Convert native parameters to Object types as args is Array[Object] here if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Long] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType - } - } - if ((parameterType.isPrimitive || args(i) != null) && - !parameterWrapperType.isInstance(args(i))) { - argMatched = false + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } } - } - i = i + 1 - } + if (!parameterWrapperType.isInstance(args(i))) { + // non primitive types + if (!parameterType.isPrimitive && args(i) != null) { + return false + } - if (argMatched) { - // For now, we return the first matching method. - // TODO: find best method in matching methods. + // primitive types + if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) { + return false + } + } + } - // Convert args if needed - val parameterTypes = parameterTypesOfMethods(index) + true + } - for (i <- 0 until numArgs) { - if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { - // Convert a Java array to scala Seq - args(i) = args(i).asInstanceOf[Array[_]].toSeq + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. + // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. + def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- parameterTypesOfMethods.indices) { + val parameterTypes = parameterTypesOfMethods(index) + + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. + } else { + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if ((parameterType.isPrimitive || args(i) != null) && + !parameterWrapperType.isInstance(args(i))) { + argMatched = false + } + } + + i = i + 1 + } + + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. + + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) + + for (i <- 0 until numArgs) { + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq + } + } + + return Some(index) + } } - } - - return Some(index) } - } + None } - None - } - def logError(id: String, e: Exception): Unit = {} + def logError(id: String, e: Exception): Unit = {} } /** @@ -332,35 +331,35 @@ class DotnetBackendHandler(server: DotnetBackend) */ private object JVMObjectTracker { - // Multiple threads may access objMap and increase objCounter. Because get method return Option, - // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. - private[this] val objMap = new HashMap[String, Object] - private[this] var objCounter: Int = 1 + // Multiple threads may access objMap and increase objCounter. Because get method return Option, + // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. + private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 - def getObject(id: String): Object = { - synchronized { - objMap(id) + def getObject(id: String): Object = { + synchronized { + objMap(id) + } } - } - def get(id: String): Option[Object] = { - synchronized { - objMap.get(id) + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) + } } - } - - def put(obj: Object): String = { - synchronized { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } } - } - def remove(id: String): Option[Object] = { - synchronized { - objMap.remove(id) + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) + } } - } } From 3e37ca3e01f8372928173e5888d7b9fd33d84718 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Tue, 8 Sep 2020 13:11:37 -0700 Subject: [PATCH 08/42] Update ThreadPool in 2.3 --- .../apache/spark/api/dotnet/ThreadPool.scala | 19 ++++++++++++------- .../apache/spark/api/dotnet/ThreadPool.scala | 6 +++--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 148df5bc3..28c40ecde 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -1,8 +1,8 @@ package org.apache.spark.api.dotnet -import java.util.concurrent.{ExecutorService, Executors, Future} +import java.util.concurrent.{ExecutorService, Executors} -import scala.collection.mutable +import scala.collection._ /** * Pool of thread executors. There should be a 1-1 correspondence between C# threads @@ -13,19 +13,24 @@ object ThreadPool { /** * Map from threadId to corresponding executor. */ - val executors: mutable.Map[Int, ExecutorService] = mutable.Map() + val executors: concurrent.TrieMap[Int, ExecutorService] = + new concurrent.TrieMap[Int, ExecutorService]() /** * Run some code on a particular thread. * * @param threadId * @param task - * @return */ - def run(threadId: Int, task: () => Unit): Future[_] = - getOrCreateExecutor(threadId).submit(new Runnable { - override def run(): Unit = task + def run(threadId: Int, task: () => Unit): Unit = { + val executor = getOrCreateExecutor(threadId) + val future = executor.submit(new Runnable { + override def run(): Unit = task() }) + while (!future.isDone) { + Thread.sleep(1000) + } + } /** * Delete a particular thread. diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 87ec0f04b..28c40ecde 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -23,10 +23,10 @@ object ThreadPool { * @param task */ def run(threadId: Int, task: () => Unit): Unit = { - val runnable = new Runnable { + val executor = getOrCreateExecutor(threadId) + val future = executor.submit(new Runnable { override def run(): Unit = task() - } - val future = getOrCreateExecutor(threadId).submit(runnable) + }) while (!future.isDone) { Thread.sleep(1000) } From e8aa36b25120794bc6cd412f7edacd4ec4797012 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Tue, 8 Sep 2020 18:43:33 -0700 Subject: [PATCH 09/42] Tests should pass now --- .../api/dotnet/DotnetBackendHandler.scala | 594 +++++++++--------- .../apache/spark/api/dotnet/ThreadPool.scala | 2 +- .../api/dotnet/DotnetBackendHandler.scala | 4 +- .../apache/spark/api/dotnet/ThreadPool.scala | 2 +- 4 files changed, 301 insertions(+), 301 deletions(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 2412a7612..5a83b2b2c 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -22,308 +22,308 @@ import org.apache.spark.util.Utils */ class DotnetBackendHandler(server: DotnetBackend) extends SimpleChannelInboundHandler[Array[Byte]] - with Logging { - - override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { - val reply = handleBackendRequest(msg) - ctx.write(reply) + with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val reply = handleBackendRequest(msg) + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + def handleBackendRequest(msg: Array[Byte]): Array[Byte] = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = readBoolean(dis) + val threadId = readInt(dis) + val objId = readString(dis) + val methodName = readString(dis) + val numArgs = readInt(dis) + + if (objId == "DotnetHandler") { + methodName match { + case "stopBackend" => + writeInt(dos, 0) + writeType(dos, "void") + server.close() + case "rm" => + try { + val t = readObjectType(dis) + assert(t == 'c') + val objToRemove = readString(dis) + ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + writeInt(dos, -1) + } + case "rmThread" => + try { + assert(readObjectType(dis) == 'i') + val threadToDelete = readInt(dis) + ThreadPool.deleteThread(threadToDelete) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + writeInt(dos, -1) + } + case "connectCallback" => + assert(readObjectType(dis) == 'c') + val address = readString(dis) + assert(readObjectType(dis) == 'i') + val port = readInt(dis) + ThreadPool.run(threadId, () => DotnetBackend.setCallbackClient(address, port)) + writeInt(dos, 0) + writeType(dos, "void") + case "closeCallback" => + logInfo("Requesting to close callback client") + ThreadPool.run(threadId, DotnetBackend.shutdownCallbackClient) + writeInt(dos, 0) + writeType(dos, "void") + + case _ => dos.writeInt(-1) + } + } else { + ThreadPool + .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) } - override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { - ctx.flush() + bos.toByteArray + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Skip logging the exception message if the connection was disconnected from + // the .NET side so that .NET side doesn't have to explicitly close the connection via + // "stopBackend." Note that an exception is still thrown if the exit status is non-zero, + // so skipping this kind of exception message does not affect the debugging. + if (!cause.getMessage.contains( + "An existing connection was forcibly closed by the remote host")) { + logError("Exception caught: ", cause) } - def handleBackendRequest(msg: Array[Byte]): Array[Byte] = { - val bis = new ByteArrayInputStream(msg) - val dis = new DataInputStream(bis) - - val bos = new ByteArrayOutputStream() - val dos = new DataOutputStream(bos) - - // First bit is isStatic - val isStatic = readBoolean(dis) - val threadId = readInt(dis) - val objId = readString(dis) - val methodName = readString(dis) - val numArgs = readInt(dis) - - if (objId == "DotnetHandler") { - methodName match { - case "stopBackend" => - writeInt(dos, 0) - writeType(dos, "void") - server.close() - case "rm" => - try { - val t = readObjectType(dis) - assert(t == 'c') - val objToRemove = readString(dis) - ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)) - writeInt(dos, 0) - writeObject(dos, null) - } catch { - case e: Exception => - logError(s"Removing $objId failed", e) - writeInt(dos, -1) - } - case "rmThread" => - try { - assert(readObjectType(dis) == 'i') - val threadToDelete = readInt(dis) - ThreadPool.deleteThread(threadToDelete) - writeInt(dos, 0) - writeObject(dos, null) - } catch { - case e: Exception => - logError(s"Removing thread $threadId failed", e) - writeInt(dos, -1) - } - case "connectCallback" => - assert(readObjectType(dis) == 'c') - val address = readString(dis) - assert(readObjectType(dis) == 'i') - val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port); - writeInt(dos, 0) - writeType(dos, "void") - case "closeCallback" => - logInfo("Requesting to close callback client") - DotnetBackend.shutdownCallbackClient() - writeInt(dos, 0) - writeType(dos, "void") - - case _ => dos.writeInt(-1) - } - } else { - ThreadPool - .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) + // Close the connection when an exception is raised. + ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + var args: Array[java.lang.Object] = null + var methods: Array[java.lang.reflect.Method] = null + + try { + val cls = if (isStatic) { + Utils.classForName(objId) + } else { + JVMObjectTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass } - - bos.toByteArray - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - // Skip logging the exception message if the connection was disconnected from - // the .NET side so that .NET side doesn't have to explicitly close the connection via - // "stopBackend." Note that an exception is still thrown if the exit status is non-zero, - // so skipping this kind of exception message does not affect the debugging. - if (!cause.getMessage.contains( - "An existing connection was forcibly closed by the remote host")) { - logError("Exception caught: ", cause) + } + + args = readArgs(numArgs, dis) + methods = cls.getMethods + + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args) + + if (index.isEmpty) { + logWarning( + s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") } - // Close the connection when an exception is raised. - ctx.close() - } - - def handleMethodCall( - isStatic: Boolean, - objId: String, - methodName: String, - numArgs: Int, - dis: DataInputStream, - dos: DataOutputStream): Unit = { - var obj: Object = null - var args: Array[java.lang.Object] = null - var methods: Array[java.lang.reflect.Method] = null - - try { - val cls = if (isStatic) { - Utils.classForName(objId) + val ret = selectedMethods(index.get).invoke(obj, args: _*) + + // Write status bit + writeInt(dos, 0) + writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args: _*) + + writeInt(dos, 0) + writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException( + "invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Exception => + val jvmObj = JVMObjectTracker.get(objId) + val jvmObjName = jvmObj match { + case Some(jObj) => jObj.getClass.getName + case None => "NullObject" + } + val argsStr = args + .map(arg => { + if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" } else { - JVMObjectTracker.get(objId) match { - case None => throw new IllegalArgumentException("Object not found " + objId) - case Some(o) => - obj = o - o.getClass - } + "[Value: NULL]" } + }) + .mkString(", ") - args = readArgs(numArgs, dis) - methods = cls.getMethods - - val selectedMethods = methods.filter(m => m.getName == methodName) - if (selectedMethods.length > 0) { - val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args) - - if (index.isEmpty) { - logWarning( - s"cannot find matching method ${cls}.$methodName. " - + s"Candidates are:") - selectedMethods.foreach { method => - logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") - } - throw new Exception(s"No matched method found for $cls.$methodName") - } - - val ret = selectedMethods(index.get).invoke(obj, args: _*) - - // Write status bit - writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) - } else if (methodName == "") { - // methodName should be "" for constructor - val ctor = cls.getConstructors.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - }.head - - val obj = ctor.newInstance(args: _*) - - writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) - } else { - throw new IllegalArgumentException( - "invalid method " + methodName + " for object " + objId) - } - } catch { - case e: Exception => - val jvmObj = JVMObjectTracker.get(objId) - val jvmObjName = jvmObj match { - case Some(jObj) => jObj.getClass.getName - case None => "NullObject" - } - val argsStr = args - .map(arg => { - if (arg != null) { - s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" - } else { - "[Value: NULL]" - } - }) - .mkString(", ") - - logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") - - if (methods != null) { - logDebug(s"All methods for $jvmObjName:") - methods.foreach(m => logDebug(m.toString)) - } - - writeInt(dos, -1) - writeString(dos, Utils.exceptionString(e.getCause)) + logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") + + if (methods != null) { + logDebug(s"All methods for $jvmObjName:") + methods.foreach(m => logDebug(m.toString)) } - } - // Read a number of arguments from the data input stream - def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { - (0 until numArgs).map { arg => - readObject(dis) - }.toArray + writeInt(dos, -1) + writeString(dos, Utils.exceptionString(e.getCause)) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. + def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false } - // Checks if the arguments passed in args matches the parameter types. - // NOTE: Currently we do exact match. We may add type conversions later. - def matchMethod( - numArgs: Int, - args: Array[java.lang.Object], - parameterTypes: Array[Class[_]]): Boolean = { - if (parameterTypes.length != numArgs) { - return false + for (i <- 0 until numArgs) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType } + } - for (i <- 0 until numArgs) { - val parameterType = parameterTypes(i) + if (!parameterWrapperType.isInstance(args(i))) { + // non primitive types + if (!parameterType.isPrimitive && args(i) != null) { + return false + } + + // primitive types + if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) { + return false + } + } + } + + true + } + + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. + // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. + def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- parameterTypesOfMethods.indices) { + val parameterTypes = parameterTypesOfMethods(index) + + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. + } else { var parameterWrapperType = parameterType // Convert native parameters to Object types as args is Array[Object] here if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Long] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType - } + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } } - - if (!parameterWrapperType.isInstance(args(i))) { - // non primitive types - if (!parameterType.isPrimitive && args(i) != null) { - return false - } - - // primitive types - if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) { - return false - } + if ((parameterType.isPrimitive || args(i) != null) && + !parameterWrapperType.isInstance(args(i))) { + argMatched = false } + } + + i = i + 1 } - true - } + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. + + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) - // Find a matching method signature in an array of signatures of constructors - // or methods of the same name according to the passed arguments. Arguments - // may be converted in order to match a signature. - // - // Note that in Java reflection, constructors and normal methods are of different - // classes, and share no parent class that provides methods for reflection uses. - // There is no unified way to handle them in this function. So an array of signatures - // is passed in instead of an array of candidate constructors or methods. - // - // Returns an Option[Int] which is the index of the matched signature in the array. - def findMatchedSignature( - parameterTypesOfMethods: Array[Array[Class[_]]], - args: Array[Object]): Option[Int] = { - val numArgs = args.length - - for (index <- parameterTypesOfMethods.indices) { - val parameterTypes = parameterTypesOfMethods(index) - - if (parameterTypes.length == numArgs) { - var argMatched = true - var i = 0 - while (i < numArgs && argMatched) { - val parameterType = parameterTypes(i) - - if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { - // The case that the parameter type is a Scala Seq and the argument - // is a Java array is considered matching. The array will be converted - // to a Seq later if this method is matched. - } else { - var parameterWrapperType = parameterType - - // Convert native parameters to Object types as args is Array[Object] here - if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Long] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType - } - } - if ((parameterType.isPrimitive || args(i) != null) && - !parameterWrapperType.isInstance(args(i))) { - argMatched = false - } - } - - i = i + 1 - } - - if (argMatched) { - // For now, we return the first matching method. - // TODO: find best method in matching methods. - - // Convert args if needed - val parameterTypes = parameterTypesOfMethods(index) - - for (i <- 0 until numArgs) { - if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { - // Convert a Java array to scala Seq - args(i) = args(i).asInstanceOf[Array[_]].toSeq - } - } - - return Some(index) - } + for (i <- 0 until numArgs) { + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq } + } + + return Some(index) } - None + } } + None + } - def logError(id: String, e: Exception): Unit = {} + def logError(id: String, e: Exception): Unit = {} } /** @@ -331,35 +331,35 @@ class DotnetBackendHandler(server: DotnetBackend) */ private object JVMObjectTracker { - // Multiple threads may access objMap and increase objCounter. Because get method return Option, - // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. - private[this] val objMap = new HashMap[String, Object] - private[this] var objCounter: Int = 1 + // Multiple threads may access objMap and increase objCounter. Because get method return Option, + // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. + private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 - def getObject(id: String): Object = { - synchronized { - objMap(id) - } + def getObject(id: String): Object = { + synchronized { + objMap(id) } + } - def get(id: String): Option[Object] = { - synchronized { - objMap.get(id) - } + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) } - - def put(obj: Object): String = { - synchronized { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } + } + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId } + } - def remove(id: String): Option[Object] = { - synchronized { - objMap.remove(id) - } + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) } + } } diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 28c40ecde..6856b872f 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -37,7 +37,7 @@ object ThreadPool { * * @param threadId */ - def deleteThread(threadId: Int) = { + def deleteThread(threadId: Int): Option[ExecutorService] = { getOrCreateExecutor(threadId).shutdown() executors.remove(threadId) } diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 7b5758822..5a83b2b2c 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -83,12 +83,12 @@ class DotnetBackendHandler(server: DotnetBackend) val address = readString(dis) assert(readObjectType(dis) == 'i') val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port); + ThreadPool.run(threadId, () => DotnetBackend.setCallbackClient(address, port)) writeInt(dos, 0) writeType(dos, "void") case "closeCallback" => logInfo("Requesting to close callback client") - DotnetBackend.shutdownCallbackClient() + ThreadPool.run(threadId, DotnetBackend.shutdownCallbackClient) writeInt(dos, 0) writeType(dos, "void") diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 28c40ecde..6856b872f 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -37,7 +37,7 @@ object ThreadPool { * * @param threadId */ - def deleteThread(threadId: Int) = { + def deleteThread(threadId: Int): Option[ExecutorService] = { getOrCreateExecutor(threadId).shutdown() executors.remove(threadId) } From 39424297302faf7194c619f1f265202fdb5dbbd0 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 21 Sep 2020 13:44:43 -0700 Subject: [PATCH 10/42] Added Spark 3 and fixed thread waiting --- .../apache/spark/api/dotnet/ThreadPool.scala | 5 +- .../apache/spark/api/dotnet/ThreadPool.scala | 5 +- .../api/dotnet/DotnetBackendHandler.scala | 45 +++++++++------ .../apache/spark/api/dotnet/ThreadPool.scala | 56 +++++++++++++++++++ 4 files changed, 89 insertions(+), 22 deletions(-) create mode 100644 src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 6856b872f..d33664d23 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -27,9 +27,8 @@ object ThreadPool { val future = executor.submit(new Runnable { override def run(): Unit = task() }) - while (!future.isDone) { - Thread.sleep(1000) - } + + future.get() } /** diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 6856b872f..d33664d23 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -27,9 +27,8 @@ object ThreadPool { val future = executor.submit(new Runnable { override def run(): Unit = task() }) - while (!future.isDone) { - Thread.sleep(1000) - } + + future.get() } /** diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 1cde1d1c5..a0c98bd2d 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,11 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.api.dotnet.SerDe._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils - import scala.collection.mutable.HashMap import scala.language.existentials +import org.apache.spark.api.dotnet.SerDe._ + /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. @@ -42,6 +39,7 @@ class DotnetBackendHandler(server: DotnetBackend) // First bit is isStatic val isStatic = readBoolean(dis) + val threadId = readInt(dis) val objId = readString(dis) val methodName = readString(dis) val numArgs = readInt(dis) @@ -57,7 +55,7 @@ class DotnetBackendHandler(server: DotnetBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) + ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)) writeInt(dos, 0) writeObject(dos, null) } catch { @@ -65,24 +63,37 @@ class DotnetBackendHandler(server: DotnetBackend) logError(s"Removing $objId failed", e) writeInt(dos, -1) } + case "rmThread" => + try { + assert(readObjectType(dis) == 'i') + val threadToDelete = readInt(dis) + ThreadPool.deleteThread(threadToDelete) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + writeInt(dos, -1) + } case "connectCallback" => assert(readObjectType(dis) == 'c') val address = readString(dis) assert(readObjectType(dis) == 'i') val port = readInt(dis) - DotnetBackend.setCallbackClient(address, port); + ThreadPool.run(threadId, () => DotnetBackend.setCallbackClient(address, port)) writeInt(dos, 0) writeType(dos, "void") case "closeCallback" => logInfo("Requesting to close callback client") - DotnetBackend.shutdownCallbackClient() + ThreadPool.run(threadId, DotnetBackend.shutdownCallbackClient) writeInt(dos, 0) writeType(dos, "void") case _ => dos.writeInt(-1) } } else { - handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + ThreadPool + .run(threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) } bos.toByteArray @@ -168,13 +179,15 @@ class DotnetBackendHandler(server: DotnetBackend) case Some(jObj) => jObj.getClass.getName case None => "NullObject" } - val argsStr = args.map(arg => { - if (arg != null) { - s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" - } else { - "[Value: NULL]" - } - }).mkString(", ") + val argsStr = args + .map(arg => { + if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" + } else { + "[Value: NULL]" + } + }) + .mkString(", ") logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala new file mode 100644 index 000000000..d33664d23 --- /dev/null +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -0,0 +1,56 @@ +package org.apache.spark.api.dotnet + +import java.util.concurrent.{ExecutorService, Executors} + +import scala.collection._ + +/** + * Pool of thread executors. There should be a 1-1 correspondence between C# threads + * and Java threads. + */ +object ThreadPool { + + /** + * Map from threadId to corresponding executor. + */ + val executors: concurrent.TrieMap[Int, ExecutorService] = + new concurrent.TrieMap[Int, ExecutorService]() + + /** + * Run some code on a particular thread. + * + * @param threadId + * @param task + */ + def run(threadId: Int, task: () => Unit): Unit = { + val executor = getOrCreateExecutor(threadId) + val future = executor.submit(new Runnable { + override def run(): Unit = task() + }) + + future.get() + } + + /** + * Delete a particular thread. + * + * @param threadId + */ + def deleteThread(threadId: Int): Option[ExecutorService] = { + getOrCreateExecutor(threadId).shutdown() + executors.remove(threadId) + } + + /** + * Get the executor if it exists, otherwise create a new one. + * + * @param id + * @return + */ + private def getOrCreateExecutor(id: Int): ExecutorService = + executors.getOrElse(id, { + val thread = Executors.newSingleThreadExecutor() + executors.put(id, thread) + thread + }) +} From b86eb00e39e332eb07ec606161a0973e6c9360e5 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 21 Sep 2020 14:42:24 -0700 Subject: [PATCH 11/42] Fixed imports --- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index a0c98bd2d..5a83b2b2c 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -11,7 +11,10 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.mutable.HashMap import scala.language.existentials +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.api.dotnet.SerDe._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils /** * Handler for DotnetBackend. From 4eceb5dbab96b8cd4b5c52ef8c48c630839a4c06 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 21 Sep 2020 16:10:10 -0700 Subject: [PATCH 12/42] Clean up ThreadPool --- .../apache/spark/api/dotnet/ThreadPool.scala | 28 ++++++++----------- .../apache/spark/api/dotnet/ThreadPool.scala | 28 ++++++++----------- .../apache/spark/api/dotnet/ThreadPool.scala | 28 ++++++++----------- 3 files changed, 36 insertions(+), 48 deletions(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index d33664d23..f3837fe22 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -2,7 +2,7 @@ package org.apache.spark.api.dotnet import java.util.concurrent.{ExecutorService, Executors} -import scala.collection._ +import scala.collection.mutable /** * Pool of thread executors. There should be a 1-1 correspondence between C# threads @@ -13,14 +13,14 @@ object ThreadPool { /** * Map from threadId to corresponding executor. */ - val executors: concurrent.TrieMap[Int, ExecutorService] = - new concurrent.TrieMap[Int, ExecutorService]() + val executors: mutable.HashMap[Int, ExecutorService] = + new mutable.HashMap[Int, ExecutorService]() /** * Run some code on a particular thread. * - * @param threadId - * @param task + * @param threadId Integer id of the thread. + * @param task Function to run on the thread. */ def run(threadId: Int, task: () => Unit): Unit = { val executor = getOrCreateExecutor(threadId) @@ -34,23 +34,19 @@ object ThreadPool { /** * Delete a particular thread. * - * @param threadId + * @param threadId Integer id of the thread. */ - def deleteThread(threadId: Int): Option[ExecutorService] = { - getOrCreateExecutor(threadId).shutdown() - executors.remove(threadId) + def deleteThread(threadId: Int): Unit = synchronized { + executors.remove(threadId).foreach(_.shutdown) } /** * Get the executor if it exists, otherwise create a new one. * - * @param id + * @param id Integer id of the thread. * @return */ - private def getOrCreateExecutor(id: Int): ExecutorService = - executors.getOrElse(id, { - val thread = Executors.newSingleThreadExecutor() - executors.put(id, thread) - thread - }) + private def getOrCreateExecutor(id: Int): ExecutorService = synchronized { + executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor) + } } diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index d33664d23..f3837fe22 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -2,7 +2,7 @@ package org.apache.spark.api.dotnet import java.util.concurrent.{ExecutorService, Executors} -import scala.collection._ +import scala.collection.mutable /** * Pool of thread executors. There should be a 1-1 correspondence between C# threads @@ -13,14 +13,14 @@ object ThreadPool { /** * Map from threadId to corresponding executor. */ - val executors: concurrent.TrieMap[Int, ExecutorService] = - new concurrent.TrieMap[Int, ExecutorService]() + val executors: mutable.HashMap[Int, ExecutorService] = + new mutable.HashMap[Int, ExecutorService]() /** * Run some code on a particular thread. * - * @param threadId - * @param task + * @param threadId Integer id of the thread. + * @param task Function to run on the thread. */ def run(threadId: Int, task: () => Unit): Unit = { val executor = getOrCreateExecutor(threadId) @@ -34,23 +34,19 @@ object ThreadPool { /** * Delete a particular thread. * - * @param threadId + * @param threadId Integer id of the thread. */ - def deleteThread(threadId: Int): Option[ExecutorService] = { - getOrCreateExecutor(threadId).shutdown() - executors.remove(threadId) + def deleteThread(threadId: Int): Unit = synchronized { + executors.remove(threadId).foreach(_.shutdown) } /** * Get the executor if it exists, otherwise create a new one. * - * @param id + * @param id Integer id of the thread. * @return */ - private def getOrCreateExecutor(id: Int): ExecutorService = - executors.getOrElse(id, { - val thread = Executors.newSingleThreadExecutor() - executors.put(id, thread) - thread - }) + private def getOrCreateExecutor(id: Int): ExecutorService = synchronized { + executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor) + } } diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index d33664d23..f3837fe22 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -2,7 +2,7 @@ package org.apache.spark.api.dotnet import java.util.concurrent.{ExecutorService, Executors} -import scala.collection._ +import scala.collection.mutable /** * Pool of thread executors. There should be a 1-1 correspondence between C# threads @@ -13,14 +13,14 @@ object ThreadPool { /** * Map from threadId to corresponding executor. */ - val executors: concurrent.TrieMap[Int, ExecutorService] = - new concurrent.TrieMap[Int, ExecutorService]() + val executors: mutable.HashMap[Int, ExecutorService] = + new mutable.HashMap[Int, ExecutorService]() /** * Run some code on a particular thread. * - * @param threadId - * @param task + * @param threadId Integer id of the thread. + * @param task Function to run on the thread. */ def run(threadId: Int, task: () => Unit): Unit = { val executor = getOrCreateExecutor(threadId) @@ -34,23 +34,19 @@ object ThreadPool { /** * Delete a particular thread. * - * @param threadId + * @param threadId Integer id of the thread. */ - def deleteThread(threadId: Int): Option[ExecutorService] = { - getOrCreateExecutor(threadId).shutdown() - executors.remove(threadId) + def deleteThread(threadId: Int): Unit = synchronized { + executors.remove(threadId).foreach(_.shutdown) } /** * Get the executor if it exists, otherwise create a new one. * - * @param id + * @param id Integer id of the thread. * @return */ - private def getOrCreateExecutor(id: Int): ExecutorService = - executors.getOrElse(id, { - val thread = Executors.newSingleThreadExecutor() - executors.put(id, thread) - thread - }) + private def getOrCreateExecutor(id: Int): ExecutorService = synchronized { + executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor) + } } From e0f8d41d6398f0193f26d42046574a9da45e5db0 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 21 Sep 2020 16:56:13 -0700 Subject: [PATCH 13/42] Add ActiveSession APIs --- .../IpcTests/Sql/SparkSessionTests.cs | 6 ++- .../Microsoft.Spark/Sql/SparkSession.cs | 37 +++++++++++++++++-- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs index 18accd1e2..d3119437a 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs @@ -36,6 +36,10 @@ public void TestSignaturesV2_3_X() Assert.IsType(SparkSession.Builder()); + SparkSession.ClearActiveSession(); + SparkSession.SetActiveSession(_spark); + Assert.IsType(SparkSession.GetActiveSession()); + SparkSession.ClearDefaultSession(); SparkSession.SetDefaultSession(_spark); Assert.IsType(SparkSession.GetDefaultSession()); @@ -75,7 +79,7 @@ public void TestSignaturesV2_4_X() /// [Fact] public void TestCreateDataFrame() - { + { // Calling CreateDataFrame with schema { var data = new List diff --git a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs index a41f585ec..c09ddaafc 100644 --- a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs +++ b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs @@ -61,10 +61,39 @@ internal SparkSession(JvmObjectReference jvmObject) /// Builder object public static Builder Builder() => new Builder(); - /// Note that *ActiveSession() APIs are not exposed because these APIs work with a - /// thread-local variable, which stores the session variable. Since the Netty server - /// that handles the requests is multi-threaded, any thread can invoke these APIs, - /// resulting in unexpected behaviors if different threads are used. + /// + /// Changes the SparkSession that will be returned in this thread and its children when + /// SparkSession.GetOrCreate() is called. This can be used to ensure that a given thread + /// receives a SparkSession with an isolated session, instead of the global (first created) + /// context. + /// + /// + public static void SetActiveSession(SparkSession session) => + session._jvmObject.Jvm.CallStaticJavaMethod( + s_sparkSessionClassName, "setActiveSession", session); + + /// + /// Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will + /// return the first created context instead of a thread-local override. + /// + public static void ClearActiveSession() => + SparkEnvironment.JvmBridge.CallStaticJavaMethod( + s_sparkSessionClassName, "clearActiveSession"); + + /// + /// Returns the active SparkSession for the current thread, returned by the builder. + /// + /// Return null, when calling this function on executors + public static SparkSession GetActiveSession() + { + var optionalSession = new Option( + (JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod( + s_sparkSessionClassName, "getActiveSession")); + + return optionalSession.IsDefined() + ? new SparkSession((JvmObjectReference)optionalSession.Get()) + : null; + } /// /// Sets the default SparkSession that is returned by the builder. From e31fb005ead9215838602999731533a74f30a909 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Tue, 22 Sep 2020 13:59:24 -0700 Subject: [PATCH 14/42] Refactor JvmThreadPool into separate class --- .../IpcTests/JvmThreadPoolTests.cs | 95 +++++++++++++++++++ .../Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 29 +----- .../Interop/Ipc/JvmThreadPool.cs | 68 +++++++++++++ 3 files changed, 167 insertions(+), 25 deletions(-) create mode 100644 src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs create mode 100644 src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs new file mode 100644 index 000000000..4c32914ea --- /dev/null +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs @@ -0,0 +1,95 @@ +using System; +using System.Threading; +using Microsoft.Spark.Interop.Ipc; +using Microsoft.Spark.Sql; +using Xunit; + +namespace Microsoft.Spark.E2ETest.IpcTests +{ + [Collection("Spark E2E Tests")] + public class JvmThreadPoolTests + { + private readonly SparkSession _spark; + private readonly IJvmBridge _jvmBridge; + + public JvmThreadPoolTests(SparkFixture fixture) + { + _spark = fixture.Spark; + _jvmBridge = ((IJvmObjectReferenceProvider)_spark).Reference.Jvm; + } + + /// + /// Test that the active SparkSession is thread-specific. + /// + [Fact] + public void TestThreadLocalSessions() + { + SparkSession.ClearActiveSession(); + + void testChildThread(string appName) + { + var thread = new Thread(() => + { + Assert.Null(SparkSession.GetActiveSession()); + + SparkSession.SetActiveSession(SparkSession.Builder().AppName(appName).GetOrCreate()); + + // Since we are in the child thread, GetActiveSession() should return the child SparkSession. + var activeSession = SparkSession.GetActiveSession(); + Assert.NotNull(activeSession); + Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null)); + }); + + thread.Start(); + while (thread.IsAlive) + { + Thread.Sleep(1000); + } + } + + for (var i = 0; i < 5; i++) + { + testChildThread(i.ToString()); + } + + Assert.Null(SparkSession.GetActiveSession()); + } + + /// + /// Add and remove a thread via the JvmThreadPool. + /// + [Fact] + public void TestAddRemoveThread() + { + var threadPool = new JvmThreadPool(_jvmBridge, TimeSpan.FromMinutes(30)); + + var thread = new Thread(() => _spark.Sql("SELECT TRUE")); + thread.Start(); + + Assert.True(threadPool.TryAddThread(thread)); + // Subsequent call should return false, because the thread has already been added. + Assert.False(threadPool.TryAddThread(thread)); + + thread.Join(); + + Assert.True(threadPool.TryRemoveThread(thread.ManagedThreadId)); + + // Subsequent call should return false, because the thread has already been removed. + Assert.False(threadPool.TryRemoveThread(thread.ManagedThreadId)); + } + + /// + /// Create a Spark worker thread in the JVM ThreadPool then remove it directly through + /// the JvmBridge. + /// + [Fact] + public void TestThreadRm() + { + // Create a thread and ensure that it is initialized in the JVM ThreadPool. + var thread = new Thread(() => _spark.Sql("SELECT TRUE")); + thread.Start(); + thread.Join(); + _jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); + } + } +} diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index e28aa66dd..b67efd532 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -36,8 +36,8 @@ internal sealed class JvmBridge : IJvmBridge private readonly ILoggerService _logger = LoggerServiceFactory.GetLogger(typeof(JvmBridge)); private readonly int _portNumber; - private readonly ConcurrentDictionary _activeThreads; - private readonly Thread _activeThreadMonitor; + private readonly JvmThreadPool _jvmThreadPool; + internal JvmBridge(int portNumber) { @@ -49,28 +49,7 @@ internal JvmBridge(int portNumber) _portNumber = portNumber; _logger.LogInfo($"JvMBridge port is {portNumber}"); - _activeThreads = new ConcurrentDictionary(); - _activeThreadMonitor = new Thread(delegate () - { - using var timer = new Timer((state) => - { - foreach (var threadId in _activeThreads.Keys) - { - if (_activeThreads.TryRemove(threadId, out Thread thread)) - { - if (thread.IsAlive) - { - _activeThreads.TryAdd(threadId, thread); - } - else - { - CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); - } - } - } - }, null, 0, 30000); - }); - _activeThreadMonitor.Start(); + _jvmThreadPool = new JvmThreadPool(this, TimeSpan.FromMinutes(30)); } private ISocketWrapper GetConnection() @@ -204,7 +183,7 @@ private object CallJavaMethod( (int)payloadMemoryStream.Position); outputStream.Flush(); - _activeThreads.TryAdd(thread.ManagedThreadId, thread); + _jvmThreadPool.TryAddThread(thread); Stream inputStream = socket.InputStream; int isMethodCallFailed = SerDe.ReadInt32(inputStream); diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs new file mode 100644 index 000000000..ebe776558 --- /dev/null +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Concurrent; +using System.Threading; + +namespace Microsoft.Spark.Interop.Ipc +{ + /// + /// This class corresponds to the ThreadPool we maintain on the JVM side. This class keeps + /// track of which .NET threads are still alive, and issues an rmThread command if a thread is not. + /// + internal class JvmThreadPool + { + private readonly IJvmBridge _jvmBridge; + private readonly ConcurrentDictionary _activeThreads; + private readonly Thread _activeThreadMonitor; + + /// + /// Construct the JvmThreadPool. + /// + /// The JvmBridge used to call JVM methods. + /// The interval to GC finished threads. + public JvmThreadPool(IJvmBridge jvmBridge, TimeSpan threadGcInterval) + { + _jvmBridge = jvmBridge; + _activeThreads = new ConcurrentDictionary(); + _activeThreadMonitor = new Thread(delegate () + { + using var timer = new Timer((state) => + { + foreach (Thread thread in _activeThreads.Values) + { + if (!thread.IsAlive) + { + TryRemoveThread(thread.ManagedThreadId); + } + } + }, null, 0, (int)threadGcInterval.TotalMilliseconds); + }); + _activeThreadMonitor.Start(); + } + + /// + /// Try to add a thread to the pool. + /// + /// The thread to add. + /// True if success, false if already added. + public bool TryAddThread(Thread thread) + { + return _activeThreads.TryAdd(thread.ManagedThreadId, thread); + } + + /// + /// Try to remove a thread. + /// + /// The ID of the thread to remove. + /// True if success, false if the thread cannot be found. + public bool TryRemoveThread(int managedThreadId) + { + if (_activeThreads.TryRemove(managedThreadId, out Thread thread)) + { + _jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); + return true; + } + + return false; + } + } +} From 1784d92c4571eb598200d0cc9ed191893e624245 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Tue, 22 Sep 2020 14:15:18 -0700 Subject: [PATCH 15/42] Clean-up --- .../IpcTests/JvmThreadPoolTests.cs | 12 +++++++++--- src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 1 - .../Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs | 12 +++++++++--- .../Microsoft.Spark/Interop/Ipc/PayloadHelper.cs | 1 - .../org/apache/spark/api/dotnet/ThreadPool.scala | 6 ++++++ .../org/apache/spark/api/dotnet/ThreadPool.scala | 6 ++++++ .../org/apache/spark/api/dotnet/ThreadPool.scala | 6 ++++++ 7 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs index 4c32914ea..84a20e6d6 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Threading; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Sql; @@ -32,9 +36,11 @@ void testChildThread(string appName) { Assert.Null(SparkSession.GetActiveSession()); - SparkSession.SetActiveSession(SparkSession.Builder().AppName(appName).GetOrCreate()); + SparkSession.SetActiveSession( + SparkSession.Builder().AppName(appName).GetOrCreate()); - // Since we are in the child thread, GetActiveSession() should return the child SparkSession. + // Since we are in the child thread, GetActiveSession() should return the child + // SparkSession. var activeSession = SparkSession.GetActiveSession(); Assert.NotNull(activeSession); Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null)); diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index b67efd532..4d35c35b3 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -38,7 +38,6 @@ internal sealed class JvmBridge : IJvmBridge private readonly int _portNumber; private readonly JvmThreadPool _jvmThreadPool; - internal JvmBridge(int portNumber) { if (portNumber == 0) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs index ebe776558..7e1026c4b 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Concurrent; using System.Threading; @@ -6,7 +10,8 @@ namespace Microsoft.Spark.Interop.Ipc { /// /// This class corresponds to the ThreadPool we maintain on the JVM side. This class keeps - /// track of which .NET threads are still alive, and issues an rmThread command if a thread is not. + /// track of which .NET threads are still alive, and issues an rmThread command if a thread is + /// not. /// internal class JvmThreadPool { @@ -58,7 +63,8 @@ public bool TryRemoveThread(int managedThreadId) { if (_activeThreads.TryRemove(managedThreadId, out Thread thread)) { - _jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); + _jvmBridge.CallStaticJavaMethod( + "DotnetHandler", "rmThread", thread.ManagedThreadId); return true; } diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs index 6aedeced1..569744713 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs @@ -8,7 +8,6 @@ using System.Collections.Generic; using System.IO; using System.Linq; -using System.Threading; using Microsoft.Spark.Sql; using Microsoft.Spark.Sql.Types; diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index f3837fe22..fe5b0dbee 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -1,3 +1,9 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + package org.apache.spark.api.dotnet import java.util.concurrent.{ExecutorService, Executors} diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index f3837fe22..fe5b0dbee 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -1,3 +1,9 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + package org.apache.spark.api.dotnet import java.util.concurrent.{ExecutorService, Executors} diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index f3837fe22..fe5b0dbee 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -1,3 +1,9 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + package org.apache.spark.api.dotnet import java.util.concurrent.{ExecutorService, Executors} From 5c45f2e61aad01c5c54732e1a7ac74361d518856 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Thu, 24 Sep 2020 16:56:27 -0700 Subject: [PATCH 16/42] Update src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala Co-authored-by: Steve Suh --- .../src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index fe5b0dbee..1307903fb 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -19,7 +19,7 @@ object ThreadPool { /** * Map from threadId to corresponding executor. */ - val executors: mutable.HashMap[Int, ExecutorService] = + private val executors: mutable.HashMap[Int, ExecutorService] = new mutable.HashMap[Int, ExecutorService]() /** From 6202f57d5c90421ab626af97afc83498e01149cd Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Thu, 24 Sep 2020 16:57:41 -0700 Subject: [PATCH 17/42] Make executors private --- .../src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala | 2 +- .../src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index fe5b0dbee..1307903fb 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -19,7 +19,7 @@ object ThreadPool { /** * Map from threadId to corresponding executor. */ - val executors: mutable.HashMap[Int, ExecutorService] = + private val executors: mutable.HashMap[Int, ExecutorService] = new mutable.HashMap[Int, ExecutorService]() /** diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index fe5b0dbee..1307903fb 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -19,7 +19,7 @@ object ThreadPool { /** * Map from threadId to corresponding executor. */ - val executors: mutable.HashMap[Int, ExecutorService] = + private val executors: mutable.HashMap[Int, ExecutorService] = new mutable.HashMap[Int, ExecutorService]() /** From ac380c5595510cdb3b44de0ac33c5365369f393b Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Thu, 24 Sep 2020 17:02:04 -0700 Subject: [PATCH 18/42] Update src/csharp/Microsoft.Spark/Sql/SparkSession.cs Co-authored-by: Steve Suh --- src/csharp/Microsoft.Spark/Sql/SparkSession.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs index c09ddaafc..f7ab53c89 100644 --- a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs +++ b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs @@ -63,9 +63,9 @@ internal SparkSession(JvmObjectReference jvmObject) /// /// Changes the SparkSession that will be returned in this thread and its children when - /// SparkSession.GetOrCreate() is called. This can be used to ensure that a given thread - /// receives a SparkSession with an isolated session, instead of the global (first created) - /// context. + /// is called. This can be used to ensure that a given + /// thread receives a SparkSession with an isolated session, instead of the global + /// (first created) context. /// /// public static void SetActiveSession(SparkSession session) => From 8f81a4f08fc0fcff22b466af829d88025f3b63e5 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Thu, 24 Sep 2020 17:02:30 -0700 Subject: [PATCH 19/42] Update src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs Co-authored-by: Steve Suh --- src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs index 7e1026c4b..e18e182be 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs @@ -32,11 +32,11 @@ public JvmThreadPool(IJvmBridge jvmBridge, TimeSpan threadGcInterval) { using var timer = new Timer((state) => { - foreach (Thread thread in _activeThreads.Values) + foreach (KeyValuePair kvp in _activeThreads) { - if (!thread.IsAlive) + if (!kvp.Value.IsAlive) { - TryRemoveThread(thread.ManagedThreadId); + TryRemoveThread(kvp.Key); } } }, null, 0, (int)threadGcInterval.TotalMilliseconds); From 5b3889023d1b4e6ef77a3922e51326d1167dc610 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Thu, 24 Sep 2020 17:02:55 -0700 Subject: [PATCH 20/42] Update src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs Co-authored-by: Steve Suh --- src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs index e18e182be..1359b37e6 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs @@ -39,7 +39,7 @@ public JvmThreadPool(IJvmBridge jvmBridge, TimeSpan threadGcInterval) TryRemoveThread(kvp.Key); } } - }, null, 0, (int)threadGcInterval.TotalMilliseconds); + }, null, threadGcInterval, threadGcInterval); }); _activeThreadMonitor.Start(); } From 7ec524b86ff7823215b8237e8542709216690421 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Thu, 24 Sep 2020 17:03:14 -0700 Subject: [PATCH 21/42] Update src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs Co-authored-by: Steve Suh --- src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs index 1359b37e6..36f72ac98 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs @@ -64,7 +64,7 @@ public bool TryRemoveThread(int managedThreadId) if (_activeThreads.TryRemove(managedThreadId, out Thread thread)) { _jvmBridge.CallStaticJavaMethod( - "DotnetHandler", "rmThread", thread.ManagedThreadId); + "DotnetHandler", "rmThread", managedThreadId); return true; } From b1d3da2f21fad9771ba9d87a43cb6a20f53e1b55 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Thu, 24 Sep 2020 17:03:25 -0700 Subject: [PATCH 22/42] Update src/csharp/Microsoft.Spark/Sql/SparkSession.cs Co-authored-by: Steve Suh --- src/csharp/Microsoft.Spark/Sql/SparkSession.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs index f7ab53c89..58a118cc1 100644 --- a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs +++ b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs @@ -73,8 +73,9 @@ public static void SetActiveSession(SparkSession session) => s_sparkSessionClassName, "setActiveSession", session); /// - /// Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will - /// return the first created context instead of a thread-local override. + /// Clears the active SparkSession for current thread. Subsequent calls to + /// will return the first created context + /// instead of a thread-local override. /// public static void ClearActiveSession() => SparkEnvironment.JvmBridge.CallStaticJavaMethod( From a45a79ddb294f397acf4b2bd1f29f6d21ef0c69e Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Fri, 25 Sep 2020 09:54:06 -0700 Subject: [PATCH 23/42] Param documentation --- src/csharp/Microsoft.Spark/Sql/SparkSession.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs index 58a118cc1..082d95ce9 100644 --- a/src/csharp/Microsoft.Spark/Sql/SparkSession.cs +++ b/src/csharp/Microsoft.Spark/Sql/SparkSession.cs @@ -62,12 +62,12 @@ internal SparkSession(JvmObjectReference jvmObject) public static Builder Builder() => new Builder(); /// - /// Changes the SparkSession that will be returned in this thread and its children when + /// Changes the SparkSession that will be returned in this thread when /// is called. This can be used to ensure that a given /// thread receives a SparkSession with an isolated session, instead of the global /// (first created) context. /// - /// + /// SparkSession object public static void SetActiveSession(SparkSession session) => session._jvmObject.Jvm.CallStaticJavaMethod( s_sparkSessionClassName, "setActiveSession", session); From f265640166f90a0fe1c601ef612f1303746d5788 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Fri, 25 Sep 2020 09:59:43 -0700 Subject: [PATCH 24/42] Refactor JvmThreadPool --- .../Interop/Ipc/JvmThreadPool.cs | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs index 36f72ac98..a9ef1ef7b 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; using System.Threading; namespace Microsoft.Spark.Interop.Ipc @@ -13,11 +14,11 @@ namespace Microsoft.Spark.Interop.Ipc /// track of which .NET threads are still alive, and issues an rmThread command if a thread is /// not. /// - internal class JvmThreadPool + internal class JvmThreadPool : IDisposable { private readonly IJvmBridge _jvmBridge; private readonly ConcurrentDictionary _activeThreads; - private readonly Thread _activeThreadMonitor; + private readonly Timer _activeThreadMonitor; /// /// Construct the JvmThreadPool. @@ -28,20 +29,17 @@ public JvmThreadPool(IJvmBridge jvmBridge, TimeSpan threadGcInterval) { _jvmBridge = jvmBridge; _activeThreads = new ConcurrentDictionary(); - _activeThreadMonitor = new Thread(delegate () - { - using var timer = new Timer((state) => - { - foreach (KeyValuePair kvp in _activeThreads) - { - if (!kvp.Value.IsAlive) - { - TryRemoveThread(kvp.Key); - } - } - }, null, threadGcInterval, threadGcInterval); - }); - _activeThreadMonitor.Start(); + _activeThreadMonitor = new Timer( + (state) => GarbageCollectThreads(), null, threadGcInterval, threadGcInterval); + } + + /// + /// Dispose of the thread monitor and run a final round of thread GC. + /// + public void Dispose() + { + _activeThreadMonitor.Dispose(); + GarbageCollectThreads(); } /// @@ -70,5 +68,19 @@ public bool TryRemoveThread(int managedThreadId) return false; } + + /// + /// Remove any threads that are no longer active. + /// + private void GarbageCollectThreads() + { + foreach (KeyValuePair kvp in _activeThreads) + { + if (!kvp.Value.IsAlive) + { + TryRemoveThread(kvp.Key); + } + } + } } } From d4610a1ca5e9e873b4b0c9b82505f8d984a3635c Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Fri, 25 Sep 2020 10:01:17 -0700 Subject: [PATCH 25/42] Dispose of the threadpool --- src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index 4d35c35b3..97903d42d 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -418,6 +418,7 @@ private object ReadCollection(Stream s) public void Dispose() { + _jvmThreadPool.Dispose(); while (_sockets.TryDequeue(out ISocketWrapper socket)) { if (socket != null) From 4922617c25a40035ca9016c3e9dee27e3df85192 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Fri, 25 Sep 2020 10:07:09 -0700 Subject: [PATCH 26/42] Add mising doc --- .../src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala | 2 +- .../src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala | 2 +- .../src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 1307903fb..6b1d7888c 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -50,7 +50,7 @@ object ThreadPool { * Get the executor if it exists, otherwise create a new one. * * @param id Integer id of the thread. - * @return + * @return The new or existing executor with the given id. */ private def getOrCreateExecutor(id: Int): ExecutorService = synchronized { executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor) diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 1307903fb..6b1d7888c 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -50,7 +50,7 @@ object ThreadPool { * Get the executor if it exists, otherwise create a new one. * * @param id Integer id of the thread. - * @return + * @return The new or existing executor with the given id. */ private def getOrCreateExecutor(id: Int): ExecutorService = synchronized { executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor) diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 1307903fb..6b1d7888c 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -50,7 +50,7 @@ object ThreadPool { * Get the executor if it exists, otherwise create a new one. * * @param id Integer id of the thread. - * @return + * @return The new or existing executor with the given id. */ private def getOrCreateExecutor(id: Int): ExecutorService = synchronized { executors.getOrElseUpdate(id, Executors.newSingleThreadExecutor) From 4d213623b4fc5713130738e7ae26a6aa1867ac83 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Fri, 25 Sep 2020 10:08:25 -0700 Subject: [PATCH 27/42] Change import order --- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 6 +++--- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 6 +++--- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 5a83b2b2c..ddcf61fcd 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import scala.collection.mutable.HashMap -import scala.language.existentials - import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils +import scala.collection.mutable.HashMap +import scala.language.existentials + /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 5a83b2b2c..ddcf61fcd 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import scala.collection.mutable.HashMap -import scala.language.existentials - import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils +import scala.collection.mutable.HashMap +import scala.language.existentials + /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index 5a83b2b2c..ddcf61fcd 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import scala.collection.mutable.HashMap -import scala.language.existentials - import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils +import scala.collection.mutable.HashMap +import scala.language.existentials + /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. From f71a1fbb23c12f8283be2f6d0e65df81c300ee64 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Fri, 25 Sep 2020 14:46:46 -0700 Subject: [PATCH 28/42] Don't need threadpool for removing object from tracker --- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 2 +- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 2 +- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index ddcf61fcd..eb170395b 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -58,7 +58,7 @@ class DotnetBackendHandler(server: DotnetBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)) + JVMObjectTracker.remove(objToRemove) writeInt(dos, 0) writeObject(dos, null) } catch { diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index ddcf61fcd..eb170395b 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -58,7 +58,7 @@ class DotnetBackendHandler(server: DotnetBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)) + JVMObjectTracker.remove(objToRemove) writeInt(dos, 0) writeObject(dos, null) } catch { diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index ddcf61fcd..eb170395b 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -58,7 +58,7 @@ class DotnetBackendHandler(server: DotnetBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - ThreadPool.run(threadId, () => JVMObjectTracker.remove(objToRemove)) + JVMObjectTracker.remove(objToRemove) writeInt(dos, 0) writeObject(dos, null) } catch { From fcf0ccf3ecd854efd6588a002bff067518d2541e Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Sun, 27 Sep 2020 16:35:46 -0700 Subject: [PATCH 29/42] Don't run callback client in thread pool --- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 4 ++-- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 4 ++-- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index eb170395b..da9839d56 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -83,12 +83,12 @@ class DotnetBackendHandler(server: DotnetBackend) val address = readString(dis) assert(readObjectType(dis) == 'i') val port = readInt(dis) - ThreadPool.run(threadId, () => DotnetBackend.setCallbackClient(address, port)) + DotnetBackend.setCallbackClient(address, port) writeInt(dos, 0) writeType(dos, "void") case "closeCallback" => logInfo("Requesting to close callback client") - ThreadPool.run(threadId, DotnetBackend.shutdownCallbackClient) + DotnetBackend.shutdownCallbackClient() writeInt(dos, 0) writeType(dos, "void") diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index eb170395b..da9839d56 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -83,12 +83,12 @@ class DotnetBackendHandler(server: DotnetBackend) val address = readString(dis) assert(readObjectType(dis) == 'i') val port = readInt(dis) - ThreadPool.run(threadId, () => DotnetBackend.setCallbackClient(address, port)) + DotnetBackend.setCallbackClient(address, port) writeInt(dos, 0) writeType(dos, "void") case "closeCallback" => logInfo("Requesting to close callback client") - ThreadPool.run(threadId, DotnetBackend.shutdownCallbackClient) + DotnetBackend.shutdownCallbackClient() writeInt(dos, 0) writeType(dos, "void") diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index eb170395b..da9839d56 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -83,12 +83,12 @@ class DotnetBackendHandler(server: DotnetBackend) val address = readString(dis) assert(readObjectType(dis) == 'i') val port = readInt(dis) - ThreadPool.run(threadId, () => DotnetBackend.setCallbackClient(address, port)) + DotnetBackend.setCallbackClient(address, port) writeInt(dos, 0) writeType(dos, "void") case "closeCallback" => logInfo("Requesting to close callback client") - ThreadPool.run(threadId, DotnetBackend.shutdownCallbackClient) + DotnetBackend.shutdownCallbackClient() writeInt(dos, 0) writeType(dos, "void") From ac95655c32cfbe2c50f62dafc21592875c442ab8 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Sun, 27 Sep 2020 19:49:14 -0700 Subject: [PATCH 30/42] Formatting --- .../Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs index a9ef1ef7b..30be0a3e5 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs @@ -24,13 +24,13 @@ internal class JvmThreadPool : IDisposable /// Construct the JvmThreadPool. /// /// The JvmBridge used to call JVM methods. - /// The interval to GC finished threads. - public JvmThreadPool(IJvmBridge jvmBridge, TimeSpan threadGcInterval) + /// The interval to GC finished threads. + public JvmThreadPool(IJvmBridge jvmBridge, TimeSpan threadGCInterval) { _jvmBridge = jvmBridge; _activeThreads = new ConcurrentDictionary(); _activeThreadMonitor = new Timer( - (state) => GarbageCollectThreads(), null, threadGcInterval, threadGcInterval); + (state) => GarbageCollectThreads(), null, threadGCInterval, threadGCInterval); } /// @@ -47,10 +47,8 @@ public void Dispose() /// /// The thread to add. /// True if success, false if already added. - public bool TryAddThread(Thread thread) - { - return _activeThreads.TryAdd(thread.ManagedThreadId, thread); - } + public bool TryAddThread(Thread thread) => + _activeThreads.TryAdd(thread.ManagedThreadId, thread); /// /// Try to remove a thread. @@ -59,7 +57,7 @@ public bool TryAddThread(Thread thread) /// True if success, false if the thread cannot be found. public bool TryRemoveThread(int managedThreadId) { - if (_activeThreads.TryRemove(managedThreadId, out Thread thread)) + if (_activeThreads.TryRemove(managedThreadId, out _)) { _jvmBridge.CallStaticJavaMethod( "DotnetHandler", "rmThread", managedThreadId); From 7e6b7bb97566c3c9ae348cb03bdcefe4beb7d70e Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 28 Sep 2020 11:12:23 -0700 Subject: [PATCH 31/42] Just join the thread --- .../Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs index 84a20e6d6..60e3458c3 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs @@ -46,11 +46,7 @@ void testChildThread(string appName) Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null)); }); - thread.Start(); - while (thread.IsAlive) - { - Thread.Sleep(1000); - } + thread.Join(); } for (var i = 0; i < 5; i++) From 5abc17a626926ff008a19546c0232e7cb9cee18e Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 28 Sep 2020 11:13:37 -0700 Subject: [PATCH 32/42] Update src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs Co-authored-by: Steve Suh --- src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs index 30be0a3e5..646cb39fe 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs @@ -59,8 +59,7 @@ public bool TryRemoveThread(int managedThreadId) { if (_activeThreads.TryRemove(managedThreadId, out _)) { - _jvmBridge.CallStaticJavaMethod( - "DotnetHandler", "rmThread", managedThreadId); + _jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", managedThreadId); return true; } From ec0661269dbd1900aa79402209d6d19460bbe1e2 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 28 Sep 2020 12:55:14 -0700 Subject: [PATCH 33/42] Fixed: start the thread --- .../Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs index 60e3458c3..d41cb7460 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs @@ -46,6 +46,7 @@ void testChildThread(string appName) Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null)); }); + thread.Start(); thread.Join(); } From 70ab43a8d6b2bfa9e7d5bdec35e2d20ae2579c17 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 28 Sep 2020 18:11:49 -0700 Subject: [PATCH 34/42] Renamed the JvmThreadPool --- ... => JvmThreadPoolGarbageCollectorTests.cs} | 15 +- .../Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 4 +- .../Interop/Ipc/JvmThreadPool.cs | 83 ----------- .../Ipc/JvmThreadPoolGarbageCollector.cs | 136 ++++++++++++++++++ 4 files changed, 143 insertions(+), 95 deletions(-) rename src/csharp/Microsoft.Spark.E2ETest/IpcTests/{JvmThreadPoolTests.cs => JvmThreadPoolGarbageCollectorTests.cs} (84%) delete mode 100644 src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs create mode 100644 src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGarbageCollectorTests.cs similarity index 84% rename from src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs rename to src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGarbageCollectorTests.cs index d41cb7460..1b2402a33 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGarbageCollectorTests.cs @@ -11,12 +11,12 @@ namespace Microsoft.Spark.E2ETest.IpcTests { [Collection("Spark E2E Tests")] - public class JvmThreadPoolTests + public class JvmThreadPoolGarbageCollectorTests { private readonly SparkSession _spark; private readonly IJvmBridge _jvmBridge; - public JvmThreadPoolTests(SparkFixture fixture) + public JvmThreadPoolGarbageCollectorTests(SparkFixture fixture) { _spark = fixture.Spark; _jvmBridge = ((IJvmObjectReferenceProvider)_spark).Reference.Jvm; @@ -59,12 +59,12 @@ void testChildThread(string appName) } /// - /// Add and remove a thread via the JvmThreadPool. + /// Monitor a thread via the JvmThreadPoolGarbageCollector. /// [Fact] - public void TestAddRemoveThread() + public void TestMonitorThread() { - var threadPool = new JvmThreadPool(_jvmBridge, TimeSpan.FromMinutes(30)); + var threadPool = new JvmThreadPoolGarbageCollector(_jvmBridge, TimeSpan.FromMinutes(30)); var thread = new Thread(() => _spark.Sql("SELECT TRUE")); thread.Start(); @@ -74,11 +74,6 @@ public void TestAddRemoveThread() Assert.False(threadPool.TryAddThread(thread)); thread.Join(); - - Assert.True(threadPool.TryRemoveThread(thread.ManagedThreadId)); - - // Subsequent call should return false, because the thread has already been removed. - Assert.False(threadPool.TryRemoveThread(thread.ManagedThreadId)); } /// diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index 97903d42d..3e8e2a440 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -36,7 +36,7 @@ internal sealed class JvmBridge : IJvmBridge private readonly ILoggerService _logger = LoggerServiceFactory.GetLogger(typeof(JvmBridge)); private readonly int _portNumber; - private readonly JvmThreadPool _jvmThreadPool; + private readonly JvmThreadPoolGarbageCollector _jvmThreadPool; internal JvmBridge(int portNumber) { @@ -48,7 +48,7 @@ internal JvmBridge(int portNumber) _portNumber = portNumber; _logger.LogInfo($"JvMBridge port is {portNumber}"); - _jvmThreadPool = new JvmThreadPool(this, TimeSpan.FromMinutes(30)); + _jvmThreadPool = new JvmThreadPoolGarbageCollector(this, TimeSpan.FromMinutes(30)); } private ISocketWrapper GetConnection() diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs deleted file mode 100644 index 646cb39fe..000000000 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPool.cs +++ /dev/null @@ -1,83 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Threading; - -namespace Microsoft.Spark.Interop.Ipc -{ - /// - /// This class corresponds to the ThreadPool we maintain on the JVM side. This class keeps - /// track of which .NET threads are still alive, and issues an rmThread command if a thread is - /// not. - /// - internal class JvmThreadPool : IDisposable - { - private readonly IJvmBridge _jvmBridge; - private readonly ConcurrentDictionary _activeThreads; - private readonly Timer _activeThreadMonitor; - - /// - /// Construct the JvmThreadPool. - /// - /// The JvmBridge used to call JVM methods. - /// The interval to GC finished threads. - public JvmThreadPool(IJvmBridge jvmBridge, TimeSpan threadGCInterval) - { - _jvmBridge = jvmBridge; - _activeThreads = new ConcurrentDictionary(); - _activeThreadMonitor = new Timer( - (state) => GarbageCollectThreads(), null, threadGCInterval, threadGCInterval); - } - - /// - /// Dispose of the thread monitor and run a final round of thread GC. - /// - public void Dispose() - { - _activeThreadMonitor.Dispose(); - GarbageCollectThreads(); - } - - /// - /// Try to add a thread to the pool. - /// - /// The thread to add. - /// True if success, false if already added. - public bool TryAddThread(Thread thread) => - _activeThreads.TryAdd(thread.ManagedThreadId, thread); - - /// - /// Try to remove a thread. - /// - /// The ID of the thread to remove. - /// True if success, false if the thread cannot be found. - public bool TryRemoveThread(int managedThreadId) - { - if (_activeThreads.TryRemove(managedThreadId, out _)) - { - _jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", managedThreadId); - return true; - } - - return false; - } - - /// - /// Remove any threads that are no longer active. - /// - private void GarbageCollectThreads() - { - foreach (KeyValuePair kvp in _activeThreads) - { - if (!kvp.Value.IsAlive) - { - TryRemoveThread(kvp.Key); - } - } - } - } -} diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs new file mode 100644 index 000000000..6251aeee3 --- /dev/null +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs @@ -0,0 +1,136 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; + +namespace Microsoft.Spark.Interop.Ipc +{ + /// + /// In .NET for Apache Spark, we maintain a 1-to-1 mapping between .NET application threads + /// and corresponding JVM threads. When a .NET thread calls a Spark API, that call is executed + /// by its corresponding JVM thread. This functionality allows for multithreaded applications + /// with thread-local variables. + /// + /// This class keeps track of the .NET application thread lifecycle. When a .NET application + /// thread is no longer alive, this class submits an rmThread command to the JVM backend to + /// dispose of its corresponding JVM thread. + /// + internal class JvmThreadPoolGarbageCollector : IDisposable + { + private readonly IJvmBridge _jvmBridge; + private readonly TimeSpan _threadGCInterval; + private readonly ConcurrentDictionary _activeThreads; + + private readonly object _activeThreadGCTimerLock; + private Timer _activeThreadGCTimer; + + /// + /// Construct the JvmThreadPoolGarbageCollector. + /// + /// The JvmBridge used to call JVM methods. + /// The interval to GC finished threads. + public JvmThreadPoolGarbageCollector(IJvmBridge jvmBridge, TimeSpan threadGCInterval) + { + _jvmBridge = jvmBridge; + _threadGCInterval = threadGCInterval; + _activeThreads = new ConcurrentDictionary(); + + _activeThreadGCTimerLock = new object(); + _activeThreadGCTimer = null; + } + + /// + /// Dispose of the GC timer and run a final round of thread GC. + /// + public void Dispose() + { + lock (_activeThreadGCTimerLock) + { + if (_activeThreadGCTimer != null) + { + _activeThreadGCTimer.Dispose(); + _activeThreadGCTimer = null; + } + } + + GarbageCollectThreads(); + } + + /// + /// Try to start monitoring a thread. + /// + /// The thread to add. + /// True if success, false if already added. + public bool TryAddThread(Thread thread) + { + bool returnValue = _activeThreads.TryAdd(thread.ManagedThreadId, thread); + + // Initialize the GC timer if necessary. + if (_activeThreadGCTimer == null) + { + lock (_activeThreadGCTimerLock) + { + if (_activeThreadGCTimer == null && _activeThreads.Count > 0) + { + _activeThreadGCTimer = new Timer( + (state) => GarbageCollectThreads(), + null, + _threadGCInterval, + _threadGCInterval); + } + } + } + + return returnValue; + } + + /// + /// Try to remove a thread from the pool. If the removal is successful, then the + /// corresponding JVM thread will also be disposed. + /// + /// The ID of the thread to remove. + /// True if success, false if the thread cannot be found. + private bool TryRemoveAndDisposeThread(int threadId) + { + if (_activeThreads.TryRemove(threadId, out _)) + { + // _activeThreads does not have ownership of the threads on the .NET side. This + // class does not need to call Join() on the .NET Thread. However, this class is + // responsible for sending the rmThread command to the JVM to trigger disposal + // of the corresponding JVM thread. + _jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId); + return true; + } + + return false; + } + + /// + /// Remove any threads that are no longer active. + /// + private void GarbageCollectThreads() + { + foreach (KeyValuePair kvp in _activeThreads) + { + if (!kvp.Value.IsAlive) + { + TryRemoveAndDisposeThread(kvp.Key); + } + } + + lock (_activeThreadGCTimerLock) + { + // Dispose of the timer if there are no threads to monitor. + if (_activeThreadGCTimer != null && _activeThreads.IsEmpty) + { + _activeThreadGCTimer.Dispose(); + _activeThreadGCTimer = null; + } + } + } + } +} From dadf60fed13ff741df37b330414bd5cd22ed7ace Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 28 Sep 2020 18:18:07 -0700 Subject: [PATCH 35/42] Comment explaining all methods are thread-safe --- .../Interop/Ipc/JvmThreadPoolGarbageCollector.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs index 6251aeee3..ea9b95230 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs @@ -17,7 +17,7 @@ namespace Microsoft.Spark.Interop.Ipc /// /// This class keeps track of the .NET application thread lifecycle. When a .NET application /// thread is no longer alive, this class submits an rmThread command to the JVM backend to - /// dispose of its corresponding JVM thread. + /// dispose of its corresponding JVM thread. All methods are thread-safe. /// internal class JvmThreadPoolGarbageCollector : IDisposable { From cf641b7128c4bde58eb7894c6fd7c1fdb9b2b3ea Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Tue, 29 Sep 2020 12:36:23 -0700 Subject: [PATCH 36/42] Make GC interval configurable --- .../JvmThreadPoolGarbageCollectorTests.cs | 17 +++++++++++++++++ .../Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 3 ++- .../Services/ConfigurationService.cs | 12 ++++++++++++ .../Services/IConfigurationService.cs | 7 +++++++ 4 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGarbageCollectorTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGarbageCollectorTests.cs index 1b2402a33..14d806f23 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGarbageCollectorTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGarbageCollectorTests.cs @@ -4,6 +4,7 @@ using System; using System.Threading; +using Microsoft.Spark.Interop; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Sql; using Xunit; @@ -89,5 +90,21 @@ public void TestThreadRm() thread.Join(); _jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); } + + /// + /// Test that the JvmThreadGarbageCollectionInterval configuration defaults to 5 minutes, + /// and can be updated correctly by setting the environment variable. + /// + [Fact] + public void TestIntervalConfiguration() + { + // Default value is 5 minutes. + Assert.Null(Environment.GetEnvironmentVariable("DOTNET_THREAD_GC_INTERVAL")); + Assert.Equal(TimeSpan.FromMinutes(5), SparkEnvironment.ConfigurationService.JvmThreadGarbageCollectionInterval); + + // Test a custom value. + Environment.SetEnvironmentVariable("DOTNET_THREAD_GC_INTERVAL", "1:30:00"); + Assert.Equal(TimeSpan.FromMinutes(90), SparkEnvironment.ConfigurationService.JvmThreadGarbageCollectionInterval); + } } } diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index 3e8e2a440..f3278c12c 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -48,7 +48,8 @@ internal JvmBridge(int portNumber) _portNumber = portNumber; _logger.LogInfo($"JvMBridge port is {portNumber}"); - _jvmThreadPool = new JvmThreadPoolGarbageCollector(this, TimeSpan.FromMinutes(30)); + _jvmThreadPool = new JvmThreadPoolGarbageCollector( + this, SparkEnvironment.ConfigurationService.JvmThreadGarbageCollectionInterval); } private ISocketWrapper GetConnection() diff --git a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs index 3b7de1555..2a0510082 100644 --- a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs +++ b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs @@ -33,6 +33,18 @@ internal sealed class ConfigurationService : IConfigurationService private string _workerPath; + /// + /// How often to run GC on JVM ThreadPool threads. + /// + public TimeSpan JvmThreadGarbageCollectionInterval + { + get + { + string envVar = Environment.GetEnvironmentVariable("DOTNET_THREAD_GC_INTERVAL"); + return string.IsNullOrEmpty(envVar) ? TimeSpan.FromMinutes(5) : TimeSpan.Parse(envVar); + } + } + /// /// Returns the port number for socket communication between JVM and CLR. /// diff --git a/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs index 5c7a4074f..43b29793d 100644 --- a/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs +++ b/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; + namespace Microsoft.Spark.Services { /// @@ -9,6 +11,11 @@ namespace Microsoft.Spark.Services /// internal interface IConfigurationService { + /// + /// How often to run GC on JVM ThreadPool threads. Default value is 5 minutes. + /// + TimeSpan JvmThreadGarbageCollectionInterval { get; } + /// /// The port number used for communicating with the .NET backend process. /// From 5ebabb1ab0bad535bc806fd1922dce97b435d62d Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Tue, 29 Sep 2020 15:14:59 -0700 Subject: [PATCH 37/42] Catch throwable --- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 2 +- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 2 +- .../org/apache/spark/api/dotnet/DotnetBackendHandler.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index da9839d56..adb9dac00 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -176,7 +176,7 @@ class DotnetBackendHandler(server: DotnetBackend) "invalid method " + methodName + " for object " + objId) } } catch { - case e: Exception => + case e: Throwable => val jvmObj = JVMObjectTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index da9839d56..adb9dac00 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -176,7 +176,7 @@ class DotnetBackendHandler(server: DotnetBackend) "invalid method " + methodName + " for object " + objId) } } catch { - case e: Exception => + case e: Throwable => val jvmObj = JVMObjectTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index da9839d56..adb9dac00 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -176,7 +176,7 @@ class DotnetBackendHandler(server: DotnetBackend) "invalid method " + methodName + " for object " + objId) } } catch { - case e: Exception => + case e: Throwable => val jvmObj = JVMObjectTracker.get(objId) val jvmObjName = jvmObj match { case Some(jObj) => jObj.getClass.getName From 70404d44c587a563f493baff96b5cf500e437dc0 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Sat, 3 Oct 2020 11:48:49 -0700 Subject: [PATCH 38/42] Address some of Terry's comments. --- ...lectorTests.cs => JvmThreadPoolGCTests.cs} | 26 +++++++++++-------- .../Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 10 +++---- ...GarbageCollector.cs => JvmThreadPoolGC.cs} | 12 ++++----- .../Services/ConfigurationService.cs | 4 +-- .../Services/IConfigurationService.cs | 4 +-- 5 files changed, 30 insertions(+), 26 deletions(-) rename src/csharp/Microsoft.Spark.E2ETest/IpcTests/{JvmThreadPoolGarbageCollectorTests.cs => JvmThreadPoolGCTests.cs} (78%) rename src/csharp/Microsoft.Spark/Interop/Ipc/{JvmThreadPoolGarbageCollector.cs => JvmThreadPoolGC.cs} (93%) diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGarbageCollectorTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs similarity index 78% rename from src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGarbageCollectorTests.cs rename to src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs index 14d806f23..5ee3a1afe 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGarbageCollectorTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs @@ -12,12 +12,12 @@ namespace Microsoft.Spark.E2ETest.IpcTests { [Collection("Spark E2E Tests")] - public class JvmThreadPoolGarbageCollectorTests + public class JvmThreadPoolGCTests { private readonly SparkSession _spark; private readonly IJvmBridge _jvmBridge; - public JvmThreadPoolGarbageCollectorTests(SparkFixture fixture) + public JvmThreadPoolGCTests(SparkFixture fixture) { _spark = fixture.Spark; _jvmBridge = ((IJvmObjectReferenceProvider)_spark).Reference.Jvm; @@ -42,7 +42,7 @@ void testChildThread(string appName) // Since we are in the child thread, GetActiveSession() should return the child // SparkSession. - var activeSession = SparkSession.GetActiveSession(); + SparkSession activeSession = SparkSession.GetActiveSession(); Assert.NotNull(activeSession); Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null)); }); @@ -51,7 +51,7 @@ void testChildThread(string appName) thread.Join(); } - for (var i = 0; i < 5; i++) + for (var i = 0; i < 5; ++i) { testChildThread(i.ToString()); } @@ -60,12 +60,12 @@ void testChildThread(string appName) } /// - /// Monitor a thread via the JvmThreadPoolGarbageCollector. + /// Monitor a thread via the JvmThreadPoolGC. /// [Fact] - public void TestMonitorThread() + public void TestTryAddThread() { - var threadPool = new JvmThreadPoolGarbageCollector(_jvmBridge, TimeSpan.FromMinutes(30)); + using var threadPool = new JvmThreadPoolGC(_jvmBridge, TimeSpan.FromMinutes(30)); var thread = new Thread(() => _spark.Sql("SELECT TRUE")); thread.Start(); @@ -92,19 +92,23 @@ public void TestThreadRm() } /// - /// Test that the JvmThreadGarbageCollectionInterval configuration defaults to 5 minutes, - /// and can be updated correctly by setting the environment variable. + /// Test that the GC interval configuration defaults to 5 minutes, and can be updated + /// correctly by setting the environment variable. /// [Fact] public void TestIntervalConfiguration() { // Default value is 5 minutes. Assert.Null(Environment.GetEnvironmentVariable("DOTNET_THREAD_GC_INTERVAL")); - Assert.Equal(TimeSpan.FromMinutes(5), SparkEnvironment.ConfigurationService.JvmThreadGarbageCollectionInterval); + Assert.Equal( + TimeSpan.FromMinutes(5), + SparkEnvironment.ConfigurationService.JvmThreadGCInterval); // Test a custom value. Environment.SetEnvironmentVariable("DOTNET_THREAD_GC_INTERVAL", "1:30:00"); - Assert.Equal(TimeSpan.FromMinutes(90), SparkEnvironment.ConfigurationService.JvmThreadGarbageCollectionInterval); + Assert.Equal( + TimeSpan.FromMinutes(90), + SparkEnvironment.ConfigurationService.JvmThreadGCInterval); } } } diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index f3278c12c..ecf4fe5c2 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -36,7 +36,7 @@ internal sealed class JvmBridge : IJvmBridge private readonly ILoggerService _logger = LoggerServiceFactory.GetLogger(typeof(JvmBridge)); private readonly int _portNumber; - private readonly JvmThreadPoolGarbageCollector _jvmThreadPool; + private readonly JvmThreadPoolGC _jvmThreadPoolGC; internal JvmBridge(int portNumber) { @@ -48,8 +48,8 @@ internal JvmBridge(int portNumber) _portNumber = portNumber; _logger.LogInfo($"JvMBridge port is {portNumber}"); - _jvmThreadPool = new JvmThreadPoolGarbageCollector( - this, SparkEnvironment.ConfigurationService.JvmThreadGarbageCollectionInterval); + _jvmThreadPoolGC = new JvmThreadPoolGC( + this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval); } private ISocketWrapper GetConnection() @@ -183,7 +183,7 @@ private object CallJavaMethod( (int)payloadMemoryStream.Position); outputStream.Flush(); - _jvmThreadPool.TryAddThread(thread); + _jvmThreadPoolGC.TryAddThread(thread); Stream inputStream = socket.InputStream; int isMethodCallFailed = SerDe.ReadInt32(inputStream); @@ -419,7 +419,7 @@ private object ReadCollection(Stream s) public void Dispose() { - _jvmThreadPool.Dispose(); + _jvmThreadPoolGC.Dispose(); while (_sockets.TryDequeue(out ISocketWrapper socket)) { if (socket != null) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs similarity index 93% rename from src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs rename to src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs index ea9b95230..c339fd93b 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGarbageCollector.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs @@ -19,7 +19,7 @@ namespace Microsoft.Spark.Interop.Ipc /// thread is no longer alive, this class submits an rmThread command to the JVM backend to /// dispose of its corresponding JVM thread. All methods are thread-safe. /// - internal class JvmThreadPoolGarbageCollector : IDisposable + internal class JvmThreadPoolGC : IDisposable { private readonly IJvmBridge _jvmBridge; private readonly TimeSpan _threadGCInterval; @@ -29,11 +29,11 @@ internal class JvmThreadPoolGarbageCollector : IDisposable private Timer _activeThreadGCTimer; /// - /// Construct the JvmThreadPoolGarbageCollector. + /// Construct the JvmThreadPoolGC. /// /// The JvmBridge used to call JVM methods. /// The interval to GC finished threads. - public JvmThreadPoolGarbageCollector(IJvmBridge jvmBridge, TimeSpan threadGCInterval) + public JvmThreadPoolGC(IJvmBridge jvmBridge, TimeSpan threadGCInterval) { _jvmBridge = jvmBridge; _threadGCInterval = threadGCInterval; @@ -57,7 +57,7 @@ public void Dispose() } } - GarbageCollectThreads(); + GCThreads(); } /// @@ -77,7 +77,7 @@ public bool TryAddThread(Thread thread) if (_activeThreadGCTimer == null && _activeThreads.Count > 0) { _activeThreadGCTimer = new Timer( - (state) => GarbageCollectThreads(), + (state) => GCThreads(), null, _threadGCInterval, _threadGCInterval); @@ -112,7 +112,7 @@ private bool TryRemoveAndDisposeThread(int threadId) /// /// Remove any threads that are no longer active. /// - private void GarbageCollectThreads() + private void GCThreads() { foreach (KeyValuePair kvp in _activeThreads) { diff --git a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs index 2a0510082..364bece47 100644 --- a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs +++ b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs @@ -34,9 +34,9 @@ internal sealed class ConfigurationService : IConfigurationService private string _workerPath; /// - /// How often to run GC on JVM ThreadPool threads. + /// How often to run GC on JVM ThreadPool threads. Defaults to 5 minutes. /// - public TimeSpan JvmThreadGarbageCollectionInterval + public TimeSpan JvmThreadGCInterval { get { diff --git a/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs index 43b29793d..5398632bd 100644 --- a/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs +++ b/src/csharp/Microsoft.Spark/Services/IConfigurationService.cs @@ -12,9 +12,9 @@ namespace Microsoft.Spark.Services internal interface IConfigurationService { /// - /// How often to run GC on JVM ThreadPool threads. Default value is 5 minutes. + /// How often to run GC on JVM ThreadPool threads. /// - TimeSpan JvmThreadGarbageCollectionInterval { get; } + TimeSpan JvmThreadGCInterval { get; } /// /// The port number used for communicating with the .NET backend process. From 7c585c8bc2da2f3fb2c0e4840756604d572e8efa Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Sat, 3 Oct 2020 12:06:40 -0700 Subject: [PATCH 39/42] Fix test name --- .../Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs index 5ee3a1afe..f0c54d6c3 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs @@ -82,7 +82,7 @@ public void TestTryAddThread() /// the JvmBridge. /// [Fact] - public void TestThreadRm() + public void TestRmThread() { // Create a thread and ensure that it is initialized in the JVM ThreadPool. var thread = new Thread(() => _spark.Sql("SELECT TRUE")); From 80fa44bb612599e6553b734ebefca2a9d3c4dfe7 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Sun, 4 Oct 2020 15:53:40 -0700 Subject: [PATCH 40/42] Deleting thread returns bool --- .../spark/api/dotnet/DotnetBackendHandler.scala | 10 +++++----- .../org/apache/spark/api/dotnet/ThreadPool.scala | 12 +++++++++--- .../spark/api/dotnet/DotnetBackendHandler.scala | 10 +++++----- .../org/apache/spark/api/dotnet/ThreadPool.scala | 12 +++++++++--- .../spark/api/dotnet/DotnetBackendHandler.scala | 11 ++++------- .../org/apache/spark/api/dotnet/ThreadPool.scala | 12 +++++++++--- 6 files changed, 41 insertions(+), 26 deletions(-) diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index adb9dac00..b22921a60 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.collection.mutable.HashMap +import scala.language.existentials + import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -import scala.collection.mutable.HashMap -import scala.language.existentials - /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. @@ -70,8 +70,8 @@ class DotnetBackendHandler(server: DotnetBackend) try { assert(readObjectType(dis) == 'i') val threadToDelete = readInt(dis) - ThreadPool.deleteThread(threadToDelete) - writeInt(dos, 0) + val result = ThreadPool.tryDeleteThread(threadToDelete) + writeBoolean(dos, result) writeObject(dos, null) } catch { case e: Exception => diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 6b1d7888c..1888ec746 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -38,12 +38,18 @@ object ThreadPool { } /** - * Delete a particular thread. + * Try to delete a particular thread. * * @param threadId Integer id of the thread. + * @return True if successful, false if thread does not exist. */ - def deleteThread(threadId: Int): Unit = synchronized { - executors.remove(threadId).foreach(_.shutdown) + def tryDeleteThread(threadId: Int): Boolean = synchronized { + executors.remove(threadId) match { + case Some(executorService) => + executorService.shutdown() + true + case None => false + } } /** diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index adb9dac00..b22921a60 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,14 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.collection.mutable.HashMap +import scala.language.existentials + import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.api.dotnet.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -import scala.collection.mutable.HashMap -import scala.language.existentials - /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. @@ -70,8 +70,8 @@ class DotnetBackendHandler(server: DotnetBackend) try { assert(readObjectType(dis) == 'i') val threadToDelete = readInt(dis) - ThreadPool.deleteThread(threadToDelete) - writeInt(dos, 0) + val result = ThreadPool.tryDeleteThread(threadToDelete) + writeBoolean(dos, result) writeObject(dos, null) } catch { case e: Exception => diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 6b1d7888c..1888ec746 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -38,12 +38,18 @@ object ThreadPool { } /** - * Delete a particular thread. + * Try to delete a particular thread. * * @param threadId Integer id of the thread. + * @return True if successful, false if thread does not exist. */ - def deleteThread(threadId: Int): Unit = synchronized { - executors.remove(threadId).foreach(_.shutdown) + def tryDeleteThread(threadId: Int): Boolean = synchronized { + executors.remove(threadId) match { + case Some(executorService) => + executorService.shutdown() + true + case None => false + } } /** diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index adb9dac00..b2a83e809 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,14 +8,11 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.api.dotnet.SerDe._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils - import scala.collection.mutable.HashMap import scala.language.existentials +import org.apache.spark.api.dotnet.SerDe._ + /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. @@ -70,8 +67,8 @@ class DotnetBackendHandler(server: DotnetBackend) try { assert(readObjectType(dis) == 'i') val threadToDelete = readInt(dis) - ThreadPool.deleteThread(threadToDelete) - writeInt(dos, 0) + val result = ThreadPool.tryDeleteThread(threadToDelete) + writeBoolean(dos, result) writeObject(dos, null) } catch { case e: Exception => diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala index 6b1d7888c..1888ec746 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -38,12 +38,18 @@ object ThreadPool { } /** - * Delete a particular thread. + * Try to delete a particular thread. * * @param threadId Integer id of the thread. + * @return True if successful, false if thread does not exist. */ - def deleteThread(threadId: Int): Unit = synchronized { - executors.remove(threadId).foreach(_.shutdown) + def tryDeleteThread(threadId: Int): Boolean = synchronized { + executors.remove(threadId) match { + case Some(executorService) => + executorService.shutdown() + true + case None => false + } } /** From 1529a27a1af1822182e1d258618b9df8ed0e8d12 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 5 Oct 2020 12:02:48 -0700 Subject: [PATCH 41/42] Address comments --- .../IpcTests/JvmThreadPoolGCTests.cs | 17 ++++++++---- .../Microsoft.Spark/Interop/Ipc/JvmBridge.cs | 2 +- .../Interop/Ipc/JvmThreadPoolGC.cs | 26 +++++++++++++++---- .../Services/ConfigurationService.cs | 2 +- .../api/dotnet/DotnetBackendHandler.scala | 4 +-- .../api/dotnet/DotnetBackendHandler.scala | 4 +-- .../api/dotnet/DotnetBackendHandler.scala | 11 +++++--- 7 files changed, 46 insertions(+), 20 deletions(-) diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs index f0c54d6c3..8e59694ac 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/JvmThreadPoolGCTests.cs @@ -6,6 +6,7 @@ using System.Threading; using Microsoft.Spark.Interop; using Microsoft.Spark.Interop.Ipc; +using Microsoft.Spark.Services; using Microsoft.Spark.Sql; using Xunit; @@ -14,11 +15,13 @@ namespace Microsoft.Spark.E2ETest.IpcTests [Collection("Spark E2E Tests")] public class JvmThreadPoolGCTests { + private readonly ILoggerService _loggerService; private readonly SparkSession _spark; private readonly IJvmBridge _jvmBridge; public JvmThreadPoolGCTests(SparkFixture fixture) { + _loggerService = LoggerServiceFactory.GetLogger(typeof(JvmThreadPoolGCTests)); _spark = fixture.Spark; _jvmBridge = ((IJvmObjectReferenceProvider)_spark).Reference.Jvm; } @@ -51,7 +54,7 @@ void testChildThread(string appName) thread.Join(); } - for (var i = 0; i < 5; ++i) + for (int i = 0; i < 5; ++i) { testChildThread(i.ToString()); } @@ -65,7 +68,8 @@ void testChildThread(string appName) [Fact] public void TestTryAddThread() { - using var threadPool = new JvmThreadPoolGC(_jvmBridge, TimeSpan.FromMinutes(30)); + using var threadPool = new JvmThreadPoolGC( + _loggerService, _jvmBridge, TimeSpan.FromMinutes(30)); var thread = new Thread(() => _spark.Sql("SELECT TRUE")); thread.Start(); @@ -88,7 +92,10 @@ public void TestRmThread() var thread = new Thread(() => _spark.Sql("SELECT TRUE")); thread.Start(); thread.Join(); - _jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId); + + // First call should return true. Second call should return false. + Assert.True((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId)); + Assert.False((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId)); } /// @@ -99,13 +106,13 @@ public void TestRmThread() public void TestIntervalConfiguration() { // Default value is 5 minutes. - Assert.Null(Environment.GetEnvironmentVariable("DOTNET_THREAD_GC_INTERVAL")); + Assert.Null(Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL")); Assert.Equal( TimeSpan.FromMinutes(5), SparkEnvironment.ConfigurationService.JvmThreadGCInterval); // Test a custom value. - Environment.SetEnvironmentVariable("DOTNET_THREAD_GC_INTERVAL", "1:30:00"); + Environment.SetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL", "1:30:00"); Assert.Equal( TimeSpan.FromMinutes(90), SparkEnvironment.ConfigurationService.JvmThreadGCInterval); diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index ecf4fe5c2..1dc53ef13 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -49,7 +49,7 @@ internal JvmBridge(int portNumber) _logger.LogInfo($"JvMBridge port is {portNumber}"); _jvmThreadPoolGC = new JvmThreadPoolGC( - this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval); + _logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval); } private ISocketWrapper GetConnection() diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs index c339fd93b..841b99189 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs @@ -6,6 +6,7 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Threading; +using Microsoft.Spark.Services; namespace Microsoft.Spark.Interop.Ipc { @@ -21,6 +22,7 @@ namespace Microsoft.Spark.Interop.Ipc /// internal class JvmThreadPoolGC : IDisposable { + private readonly ILoggerService _loggerService; private readonly IJvmBridge _jvmBridge; private readonly TimeSpan _threadGCInterval; private readonly ConcurrentDictionary _activeThreads; @@ -31,10 +33,12 @@ internal class JvmThreadPoolGC : IDisposable /// /// Construct the JvmThreadPoolGC. /// + /// Logger service. /// The JvmBridge used to call JVM methods. /// The interval to GC finished threads. - public JvmThreadPoolGC(IJvmBridge jvmBridge, TimeSpan threadGCInterval) + public JvmThreadPoolGC(ILoggerService loggerService, IJvmBridge jvmBridge, TimeSpan threadGCInterval) { + _loggerService = loggerService; _jvmBridge = jvmBridge; _threadGCInterval = threadGCInterval; _activeThreads = new ConcurrentDictionary(); @@ -94,7 +98,7 @@ public bool TryAddThread(Thread thread) /// /// The ID of the thread to remove. /// True if success, false if the thread cannot be found. - private bool TryRemoveAndDisposeThread(int threadId) + private bool TryDisposeJvmThread(int threadId) { if (_activeThreads.TryRemove(threadId, out _)) { @@ -102,8 +106,17 @@ private bool TryRemoveAndDisposeThread(int threadId) // class does not need to call Join() on the .NET Thread. However, this class is // responsible for sending the rmThread command to the JVM to trigger disposal // of the corresponding JVM thread. - _jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId); - return true; + if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId)) + { + _loggerService.LogDebug($"GC'd JVM thread {threadId}."); + } + else + { + _loggerService.LogWarn( + $"rmThread returned false for JVM thread {threadId}. " + + $"Either thread does not exist or has already been GC'd."); + return false; + } } return false; @@ -114,11 +127,12 @@ private bool TryRemoveAndDisposeThread(int threadId) /// private void GCThreads() { + _loggerService.LogDebug("Starting JVM thread GC."); foreach (KeyValuePair kvp in _activeThreads) { if (!kvp.Value.IsAlive) { - TryRemoveAndDisposeThread(kvp.Key); + TryDisposeJvmThread(kvp.Key); } } @@ -131,6 +145,8 @@ private void GCThreads() _activeThreadGCTimer = null; } } + + _loggerService.LogDebug("JVM thread GC complete."); } } } diff --git a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs index 364bece47..4ce565c84 100644 --- a/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs +++ b/src/csharp/Microsoft.Spark/Services/ConfigurationService.cs @@ -40,7 +40,7 @@ public TimeSpan JvmThreadGCInterval { get { - string envVar = Environment.GetEnvironmentVariable("DOTNET_THREAD_GC_INTERVAL"); + string envVar = Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL"); return string.IsNullOrEmpty(envVar) ? TimeSpan.FromMinutes(5) : TimeSpan.Parse(envVar); } } diff --git a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index b22921a60..e632589e4 100644 --- a/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -71,8 +71,8 @@ class DotnetBackendHandler(server: DotnetBackend) assert(readObjectType(dis) == 'i') val threadToDelete = readInt(dis) val result = ThreadPool.tryDeleteThread(threadToDelete) - writeBoolean(dos, result) - writeObject(dos, null) + writeInt(dos, 0) + writeObject(dos, result.asInstanceOf[AnyRef]) } catch { case e: Exception => logError(s"Removing thread $threadId failed", e) diff --git a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index b22921a60..e632589e4 100644 --- a/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -71,8 +71,8 @@ class DotnetBackendHandler(server: DotnetBackend) assert(readObjectType(dis) == 'i') val threadToDelete = readInt(dis) val result = ThreadPool.tryDeleteThread(threadToDelete) - writeBoolean(dos, result) - writeObject(dos, null) + writeInt(dos, 0) + writeObject(dos, result.asInstanceOf[AnyRef]) } catch { case e: Exception => logError(s"Removing thread $threadId failed", e) diff --git a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala index b2a83e809..1446e5ff6 100644 --- a/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ b/src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -8,11 +8,14 @@ package org.apache.spark.api.dotnet import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} +import org.apache.spark.api.dotnet.SerDe._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + import scala.collection.mutable.HashMap import scala.language.existentials -import org.apache.spark.api.dotnet.SerDe._ - /** * Handler for DotnetBackend. * This implementation is similar to RBackendHandler. @@ -68,8 +71,8 @@ class DotnetBackendHandler(server: DotnetBackend) assert(readObjectType(dis) == 'i') val threadToDelete = readInt(dis) val result = ThreadPool.tryDeleteThread(threadToDelete) - writeBoolean(dos, result) - writeObject(dos, null) + writeInt(dos, 0) + writeObject(dos, result.asInstanceOf[AnyRef]) } catch { case e: Exception => logError(s"Removing thread $threadId failed", e) From 789a5559d1ed21b4bbe0f6ad53f35415674eea27 Mon Sep 17 00:00:00 2001 From: Andrew Fogarty Date: Mon, 5 Oct 2020 15:52:19 -0700 Subject: [PATCH 42/42] Fix return value and clean up logs --- src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs index 841b99189..0eacebadd 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmThreadPoolGC.cs @@ -109,13 +109,13 @@ private bool TryDisposeJvmThread(int threadId) if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId)) { _loggerService.LogDebug($"GC'd JVM thread {threadId}."); + return true; } else { _loggerService.LogWarn( $"rmThread returned false for JVM thread {threadId}. " + $"Either thread does not exist or has already been GC'd."); - return false; } } @@ -127,7 +127,6 @@ private bool TryDisposeJvmThread(int threadId) /// private void GCThreads() { - _loggerService.LogDebug("Starting JVM thread GC."); foreach (KeyValuePair kvp in _activeThreads) { if (!kvp.Value.IsAlive) @@ -145,8 +144,6 @@ private void GCThreads() _activeThreadGCTimer = null; } } - - _loggerService.LogDebug("JVM thread GC complete."); } } }