Skip to content

Commit 1767c3e

Browse files
authored
Fix for memory leak in JVMObjectTracker (#801)
1 parent 2346553 commit 1767c3e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2213
-542
lines changed

src/csharp/Microsoft.Spark.UnitTest/CallbackTests.cs

+11
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,17 @@ public void TestCallbackHandlers()
138138
Assert.Empty(callbackHandler.Inputs);
139139
}
140140
}
141+
142+
[Fact]
143+
public void TestJvmCallbackClientProperty()
144+
{
145+
var server = new CallbackServer(_mockJvm.Object, run: false);
146+
Assert.Throws<InvalidOperationException>(() => server.JvmCallbackClient);
147+
148+
using ISocketWrapper callbackSocket = SocketFactory.CreateSocket();
149+
server.Run(callbackSocket);
150+
Assert.NotNull(server.JvmCallbackClient);
151+
}
141152

142153
private void TestCallbackConnection(
143154
ConcurrentDictionary<int, ICallbackHandler> callbackHandlersDict,

src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs

+17-1
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,25 @@ internal sealed class CallbackServer
6464
private bool _isRunning = false;
6565

6666
private ISocketWrapper _listener;
67+
68+
private JvmObjectReference _jvmCallbackClient;
6769

6870
internal int CurrentNumConnections => _connections.Count;
6971

72+
internal JvmObjectReference JvmCallbackClient
73+
{
74+
get
75+
{
76+
if (_jvmCallbackClient is null)
77+
{
78+
throw new InvalidOperationException(
79+
"Please make sure that CallbackServer was started before accessing JvmCallbackClient.");
80+
}
81+
82+
return _jvmCallbackClient;
83+
}
84+
}
85+
7086
internal CallbackServer(IJvmBridge jvm, bool run = true)
7187
{
7288
AppDomain.CurrentDomain.ProcessExit += (s, e) => Shutdown();
@@ -113,7 +129,7 @@ internal void Run(ISocketWrapper listener)
113129

114130
// Communicate with the JVM the callback server's address and port.
115131
var localEndPoint = (IPEndPoint)_listener.LocalEndPoint;
116-
_jvm.CallStaticJavaMethod(
132+
_jvmCallbackClient = (JvmObjectReference)_jvm.CallStaticJavaMethod(
117133
"DotnetHandler",
118134
"connectCallback",
119135
localEndPoint.Address.ToString(),

src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs

+1
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ public DataStreamWriter ForeachBatch(Action<DataFrame, long> func)
228228
_jvmObject.Jvm.CallStaticJavaMethod(
229229
"org.apache.spark.sql.api.dotnet.DotnetForeachBatchHelper",
230230
"callForeachBatch",
231+
SparkEnvironment.CallbackServer.JvmCallbackClient,
231232
this,
232233
callbackId);
233234
return this;

src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ import scala.collection.mutable.Queue
2020
* @param address The address of the Dotnet CallbackServer
2121
* @param port The port of the Dotnet CallbackServer
2222
*/
23-
class CallbackClient(address: String, port: Int) extends Logging {
23+
class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging {
2424
private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]()
2525

2626
private[this] var isShutdown: Boolean = false
2727

28-
final def send(callbackId: Int, writeBody: DataOutputStream => Unit): Unit =
28+
final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit =
2929
getOrCreateConnection() match {
3030
case Some(connection) =>
3131
try {
@@ -50,7 +50,7 @@ class CallbackClient(address: String, port: Int) extends Logging {
5050
return Some(connectionPool.dequeue())
5151
}
5252

53-
Some(new CallbackConnection(address, port))
53+
Some(new CallbackConnection(serDe, address, port))
5454
}
5555

5656
private def addConnection(connection: CallbackConnection): Unit = synchronized {

src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala

+11-11
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,23 @@ import org.apache.spark.internal.Logging
1818
* @param address The address of the Dotnet CallbackServer
1919
* @param port The port of the Dotnet CallbackServer
2020
*/
21-
class CallbackConnection(address: String, port: Int) extends Logging {
21+
class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging {
2222
private[this] val socket: Socket = new Socket(address, port)
2323
private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream)
2424
private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream)
2525

2626
def send(
2727
callbackId: Int,
28-
writeBody: DataOutputStream => Unit): Unit = {
28+
writeBody: (DataOutputStream, SerDe) => Unit): Unit = {
2929
logInfo(s"Calling callback [callback id = $callbackId] ...")
3030

3131
try {
32-
SerDe.writeInt(outputStream, CallbackFlags.CALLBACK)
33-
SerDe.writeInt(outputStream, callbackId)
32+
serDe.writeInt(outputStream, CallbackFlags.CALLBACK)
33+
serDe.writeInt(outputStream, callbackId)
3434

3535
val byteArrayOutputStream = new ByteArrayOutputStream()
36-
writeBody(new DataOutputStream(byteArrayOutputStream))
37-
SerDe.writeInt(outputStream, byteArrayOutputStream.size)
36+
writeBody(new DataOutputStream(byteArrayOutputStream), serDe)
37+
serDe.writeInt(outputStream, byteArrayOutputStream.size)
3838
byteArrayOutputStream.writeTo(outputStream);
3939
} catch {
4040
case e: Exception => {
@@ -44,7 +44,7 @@ class CallbackConnection(address: String, port: Int) extends Logging {
4444

4545
logInfo(s"Signaling END_OF_STREAM.")
4646
try {
47-
SerDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM)
47+
serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM)
4848
outputStream.flush()
4949

5050
val endOfStreamResponse = readFlag(inputStream)
@@ -65,7 +65,7 @@ class CallbackConnection(address: String, port: Int) extends Logging {
6565

6666
def close(): Unit = {
6767
try {
68-
SerDe.writeInt(outputStream, CallbackFlags.CLOSE)
68+
serDe.writeInt(outputStream, CallbackFlags.CLOSE)
6969
outputStream.flush()
7070
} catch {
7171
case e: Exception => logInfo("Unable to send close to .NET callback server.", e)
@@ -95,9 +95,9 @@ class CallbackConnection(address: String, port: Int) extends Logging {
9595
}
9696

9797
private def readFlag(inputStream: DataInputStream): Int = {
98-
val callbackFlag = SerDe.readInt(inputStream)
98+
val callbackFlag = serDe.readInt(inputStream)
9999
if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) {
100-
val exceptionMessage = SerDe.readString(inputStream)
100+
val exceptionMessage = serDe.readString(inputStream)
101101
throw new DotnetException(exceptionMessage)
102102
}
103103
callbackFlag
@@ -109,4 +109,4 @@ class CallbackConnection(address: String, port: Int) extends Logging {
109109
val DOTNET_EXCEPTION_THROWN: Int = -3
110110
val END_OF_STREAM: Int = -4
111111
}
112-
}
112+
}

src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala

+25-22
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class DotnetBackend extends Logging {
3030
private[this] var channelFuture: ChannelFuture = _
3131
private[this] var bootstrap: ServerBootstrap = _
3232
private[this] var bossGroup: EventLoopGroup = _
33+
private[this] val objectTracker = new JVMObjectTracker
34+
35+
@volatile
36+
private[dotnet] var callbackClient: Option[CallbackClient] = None
3337

3438
def init(portNumber: Int): Int = {
3539
val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
@@ -55,7 +59,7 @@ class DotnetBackend extends Logging {
5559
// initialBytesToStrip = 4, i.e. strip out the length field itself
5660
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
5761
.addLast("decoder", new ByteArrayDecoder())
58-
.addLast("handler", new DotnetBackendHandler(self))
62+
.addLast("handler", new DotnetBackendHandler(self, objectTracker))
5963
}
6064
})
6165

@@ -64,6 +68,23 @@ class DotnetBackend extends Logging {
6468
channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
6569
}
6670

71+
private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized {
72+
callbackClient = callbackClient match {
73+
case Some(_) => throw new Exception("Callback client already set.")
74+
case None =>
75+
logInfo(s"Connecting to a callback server at $address:$port")
76+
Some(new CallbackClient(new SerDe(objectTracker), address, port))
77+
}
78+
}
79+
80+
private[dotnet] def shutdownCallbackClient(): Unit = synchronized {
81+
callbackClient match {
82+
case Some(client) => client.shutdown()
83+
case None => logInfo("Callback server has already been shutdown.")
84+
}
85+
callbackClient = None
86+
}
87+
6788
def run(): Unit = {
6889
channelFuture.channel.closeFuture().syncUninterruptibly()
6990
}
@@ -82,30 +103,12 @@ class DotnetBackend extends Logging {
82103
}
83104
bootstrap = null
84105

106+
objectTracker.clear()
107+
85108
// Send close to .NET callback server.
86-
DotnetBackend.shutdownCallbackClient()
109+
shutdownCallbackClient()
87110

88111
// Shutdown the thread pool whose executors could still be running.
89112
ThreadPool.shutdown()
90113
}
91114
}
92-
93-
object DotnetBackend extends Logging {
94-
@volatile private[spark] var callbackClient: CallbackClient = null
95-
96-
private[spark] def setCallbackClient(address: String, port: Int) = synchronized {
97-
if (DotnetBackend.callbackClient == null) {
98-
logInfo(s"Connecting to a callback server at $address:$port")
99-
DotnetBackend.callbackClient = new CallbackClient(address, port)
100-
} else {
101-
throw new Exception("Callback client already set.")
102-
}
103-
}
104-
105-
private[spark] def shutdownCallbackClient(): Unit = synchronized {
106-
if (callbackClient != null) {
107-
callbackClient.shutdown()
108-
callbackClient = null
109-
}
110-
}
111-
}

0 commit comments

Comments
 (0)