Skip to content

Commit a67ad59

Browse files
authored
Full support for multithreaded applications (#641)
1 parent 7bcd2a5 commit a67ad59

File tree

14 files changed

+618
-43
lines changed

14 files changed

+618
-43
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Threading;
7+
using Microsoft.Spark.Interop;
8+
using Microsoft.Spark.Interop.Ipc;
9+
using Microsoft.Spark.Services;
10+
using Microsoft.Spark.Sql;
11+
using Xunit;
12+
13+
namespace Microsoft.Spark.E2ETest.IpcTests
14+
{
15+
[Collection("Spark E2E Tests")]
16+
public class JvmThreadPoolGCTests
17+
{
18+
private readonly ILoggerService _loggerService;
19+
private readonly SparkSession _spark;
20+
private readonly IJvmBridge _jvmBridge;
21+
22+
public JvmThreadPoolGCTests(SparkFixture fixture)
23+
{
24+
_loggerService = LoggerServiceFactory.GetLogger(typeof(JvmThreadPoolGCTests));
25+
_spark = fixture.Spark;
26+
_jvmBridge = ((IJvmObjectReferenceProvider)_spark).Reference.Jvm;
27+
}
28+
29+
/// <summary>
30+
/// Test that the active SparkSession is thread-specific.
31+
/// </summary>
32+
[Fact]
33+
public void TestThreadLocalSessions()
34+
{
35+
SparkSession.ClearActiveSession();
36+
37+
void testChildThread(string appName)
38+
{
39+
var thread = new Thread(() =>
40+
{
41+
Assert.Null(SparkSession.GetActiveSession());
42+
43+
SparkSession.SetActiveSession(
44+
SparkSession.Builder().AppName(appName).GetOrCreate());
45+
46+
// Since we are in the child thread, GetActiveSession() should return the child
47+
// SparkSession.
48+
SparkSession activeSession = SparkSession.GetActiveSession();
49+
Assert.NotNull(activeSession);
50+
Assert.Equal(appName, activeSession.Conf().Get("spark.app.name", null));
51+
});
52+
53+
thread.Start();
54+
thread.Join();
55+
}
56+
57+
for (int i = 0; i < 5; ++i)
58+
{
59+
testChildThread(i.ToString());
60+
}
61+
62+
Assert.Null(SparkSession.GetActiveSession());
63+
}
64+
65+
/// <summary>
66+
/// Monitor a thread via the JvmThreadPoolGC.
67+
/// </summary>
68+
[Fact]
69+
public void TestTryAddThread()
70+
{
71+
using var threadPool = new JvmThreadPoolGC(
72+
_loggerService, _jvmBridge, TimeSpan.FromMinutes(30));
73+
74+
var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
75+
thread.Start();
76+
77+
Assert.True(threadPool.TryAddThread(thread));
78+
// Subsequent call should return false, because the thread has already been added.
79+
Assert.False(threadPool.TryAddThread(thread));
80+
81+
thread.Join();
82+
}
83+
84+
/// <summary>
85+
/// Create a Spark worker thread in the JVM ThreadPool then remove it directly through
86+
/// the JvmBridge.
87+
/// </summary>
88+
[Fact]
89+
public void TestRmThread()
90+
{
91+
// Create a thread and ensure that it is initialized in the JVM ThreadPool.
92+
var thread = new Thread(() => _spark.Sql("SELECT TRUE"));
93+
thread.Start();
94+
thread.Join();
95+
96+
// First call should return true. Second call should return false.
97+
Assert.True((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId));
98+
Assert.False((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", thread.ManagedThreadId));
99+
}
100+
101+
/// <summary>
102+
/// Test that the GC interval configuration defaults to 5 minutes, and can be updated
103+
/// correctly by setting the environment variable.
104+
/// </summary>
105+
[Fact]
106+
public void TestIntervalConfiguration()
107+
{
108+
// Default value is 5 minutes.
109+
Assert.Null(Environment.GetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL"));
110+
Assert.Equal(
111+
TimeSpan.FromMinutes(5),
112+
SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
113+
114+
// Test a custom value.
115+
Environment.SetEnvironmentVariable("DOTNET_JVM_THREAD_GC_INTERVAL", "1:30:00");
116+
Assert.Equal(
117+
TimeSpan.FromMinutes(90),
118+
SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
119+
}
120+
}
121+
}

src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ public void TestSignaturesV2_3_X()
3535

3636
Assert.IsType<Builder>(SparkSession.Builder());
3737

38+
SparkSession.ClearActiveSession();
39+
SparkSession.SetActiveSession(_spark);
40+
Assert.IsType<SparkSession>(SparkSession.GetActiveSession());
41+
3842
SparkSession.ClearDefaultSession();
3943
SparkSession.SetDefaultSession(_spark);
4044
Assert.IsType<SparkSession>(SparkSession.GetDefaultSession());
@@ -76,7 +80,7 @@ public void TestSignaturesV2_4_X()
7680
/// </summary>
7781
[Fact]
7882
public void TestCreateDataFrame()
79-
{
83+
{
8084
// Calling CreateDataFrame with schema
8185
{
8286
var data = new List<GenericRow>

src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs

+10
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.IO;
99
using System.Net;
1010
using System.Text;
11+
using System.Threading;
1112
using Microsoft.Spark.Network;
1213
using Microsoft.Spark.Services;
1314

@@ -35,6 +36,7 @@ internal sealed class JvmBridge : IJvmBridge
3536
private readonly ILoggerService _logger =
3637
LoggerServiceFactory.GetLogger(typeof(JvmBridge));
3738
private readonly int _portNumber;
39+
private readonly JvmThreadPoolGC _jvmThreadPoolGC;
3840

3941
internal JvmBridge(int portNumber)
4042
{
@@ -45,6 +47,9 @@ internal JvmBridge(int portNumber)
4547

4648
_portNumber = portNumber;
4749
_logger.LogInfo($"JvMBridge port is {portNumber}");
50+
51+
_jvmThreadPoolGC = new JvmThreadPoolGC(
52+
_logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval);
4853
}
4954

5055
private ISocketWrapper GetConnection()
@@ -158,11 +163,13 @@ private object CallJavaMethod(
158163
ISocketWrapper socket = null;
159164
try
160165
{
166+
Thread thread = Thread.CurrentThread;
161167
MemoryStream payloadMemoryStream = s_payloadMemoryStream ??= new MemoryStream();
162168
payloadMemoryStream.Position = 0;
163169
PayloadHelper.BuildPayload(
164170
payloadMemoryStream,
165171
isStatic,
172+
thread.ManagedThreadId,
166173
classNameOrJvmObjectReference,
167174
methodName,
168175
args);
@@ -176,6 +183,8 @@ private object CallJavaMethod(
176183
(int)payloadMemoryStream.Position);
177184
outputStream.Flush();
178185

186+
_jvmThreadPoolGC.TryAddThread(thread);
187+
179188
Stream inputStream = socket.InputStream;
180189
int isMethodCallFailed = SerDe.ReadInt32(inputStream);
181190
if (isMethodCallFailed != 0)
@@ -410,6 +419,7 @@ private object ReadCollection(Stream s)
410419

411420
public void Dispose()
412421
{
422+
_jvmThreadPoolGC.Dispose();
413423
while (_sockets.TryDequeue(out ISocketWrapper socket))
414424
{
415425
if (socket != null)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Concurrent;
7+
using System.Collections.Generic;
8+
using System.Threading;
9+
using Microsoft.Spark.Services;
10+
11+
namespace Microsoft.Spark.Interop.Ipc
12+
{
13+
/// <summary>
14+
/// In .NET for Apache Spark, we maintain a 1-to-1 mapping between .NET application threads
15+
/// and corresponding JVM threads. When a .NET thread calls a Spark API, that call is executed
16+
/// by its corresponding JVM thread. This functionality allows for multithreaded applications
17+
/// with thread-local variables.
18+
///
19+
/// This class keeps track of the .NET application thread lifecycle. When a .NET application
20+
/// thread is no longer alive, this class submits an rmThread command to the JVM backend to
21+
/// dispose of its corresponding JVM thread. All methods are thread-safe.
22+
/// </summary>
23+
internal class JvmThreadPoolGC : IDisposable
24+
{
25+
private readonly ILoggerService _loggerService;
26+
private readonly IJvmBridge _jvmBridge;
27+
private readonly TimeSpan _threadGCInterval;
28+
private readonly ConcurrentDictionary<int, Thread> _activeThreads;
29+
30+
private readonly object _activeThreadGCTimerLock;
31+
private Timer _activeThreadGCTimer;
32+
33+
/// <summary>
34+
/// Construct the JvmThreadPoolGC.
35+
/// </summary>
36+
/// <param name="loggerService">Logger service.</param>
37+
/// <param name="jvmBridge">The JvmBridge used to call JVM methods.</param>
38+
/// <param name="threadGCInterval">The interval to GC finished threads.</param>
39+
public JvmThreadPoolGC(ILoggerService loggerService, IJvmBridge jvmBridge, TimeSpan threadGCInterval)
40+
{
41+
_loggerService = loggerService;
42+
_jvmBridge = jvmBridge;
43+
_threadGCInterval = threadGCInterval;
44+
_activeThreads = new ConcurrentDictionary<int, Thread>();
45+
46+
_activeThreadGCTimerLock = new object();
47+
_activeThreadGCTimer = null;
48+
}
49+
50+
/// <summary>
51+
/// Dispose of the GC timer and run a final round of thread GC.
52+
/// </summary>
53+
public void Dispose()
54+
{
55+
lock (_activeThreadGCTimerLock)
56+
{
57+
if (_activeThreadGCTimer != null)
58+
{
59+
_activeThreadGCTimer.Dispose();
60+
_activeThreadGCTimer = null;
61+
}
62+
}
63+
64+
GCThreads();
65+
}
66+
67+
/// <summary>
68+
/// Try to start monitoring a thread.
69+
/// </summary>
70+
/// <param name="thread">The thread to add.</param>
71+
/// <returns>True if success, false if already added.</returns>
72+
public bool TryAddThread(Thread thread)
73+
{
74+
bool returnValue = _activeThreads.TryAdd(thread.ManagedThreadId, thread);
75+
76+
// Initialize the GC timer if necessary.
77+
if (_activeThreadGCTimer == null)
78+
{
79+
lock (_activeThreadGCTimerLock)
80+
{
81+
if (_activeThreadGCTimer == null && _activeThreads.Count > 0)
82+
{
83+
_activeThreadGCTimer = new Timer(
84+
(state) => GCThreads(),
85+
null,
86+
_threadGCInterval,
87+
_threadGCInterval);
88+
}
89+
}
90+
}
91+
92+
return returnValue;
93+
}
94+
95+
/// <summary>
96+
/// Try to remove a thread from the pool. If the removal is successful, then the
97+
/// corresponding JVM thread will also be disposed.
98+
/// </summary>
99+
/// <param name="threadId">The ID of the thread to remove.</param>
100+
/// <returns>True if success, false if the thread cannot be found.</returns>
101+
private bool TryDisposeJvmThread(int threadId)
102+
{
103+
if (_activeThreads.TryRemove(threadId, out _))
104+
{
105+
// _activeThreads does not have ownership of the threads on the .NET side. This
106+
// class does not need to call Join() on the .NET Thread. However, this class is
107+
// responsible for sending the rmThread command to the JVM to trigger disposal
108+
// of the corresponding JVM thread.
109+
if ((bool)_jvmBridge.CallStaticJavaMethod("DotnetHandler", "rmThread", threadId))
110+
{
111+
_loggerService.LogDebug($"GC'd JVM thread {threadId}.");
112+
return true;
113+
}
114+
else
115+
{
116+
_loggerService.LogWarn(
117+
$"rmThread returned false for JVM thread {threadId}. " +
118+
$"Either thread does not exist or has already been GC'd.");
119+
}
120+
}
121+
122+
return false;
123+
}
124+
125+
/// <summary>
126+
/// Remove any threads that are no longer active.
127+
/// </summary>
128+
private void GCThreads()
129+
{
130+
foreach (KeyValuePair<int, Thread> kvp in _activeThreads)
131+
{
132+
if (!kvp.Value.IsAlive)
133+
{
134+
TryDisposeJvmThread(kvp.Key);
135+
}
136+
}
137+
138+
lock (_activeThreadGCTimerLock)
139+
{
140+
// Dispose of the timer if there are no threads to monitor.
141+
if (_activeThreadGCTimer != null && _activeThreads.IsEmpty)
142+
{
143+
_activeThreadGCTimer.Dispose();
144+
_activeThreadGCTimer = null;
145+
}
146+
}
147+
}
148+
}
149+
}

src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ internal class PayloadHelper
2727
private static readonly byte[] s_timestampTypeId = new[] { (byte)'t' };
2828
private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' };
2929
private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' };
30-
private static readonly byte[] s_doubleArrayArrayTypeId = new[] { ( byte)'A' };
30+
private static readonly byte[] s_doubleArrayArrayTypeId = new[] { (byte)'A' };
3131
private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' };
3232
private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
3333
private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };
@@ -39,6 +39,7 @@ internal class PayloadHelper
3939
internal static void BuildPayload(
4040
MemoryStream destination,
4141
bool isStaticMethod,
42+
int threadId,
4243
object classNameOrJvmObjectReference,
4344
string methodName,
4445
object[] args)
@@ -48,6 +49,7 @@ internal static void BuildPayload(
4849
destination.Position += sizeof(int);
4950

5051
SerDe.Write(destination, isStaticMethod);
52+
SerDe.Write(destination, threadId);
5153
SerDe.Write(destination, classNameOrJvmObjectReference.ToString());
5254
SerDe.Write(destination, methodName);
5355
SerDe.Write(destination, args.Length);
@@ -140,7 +142,7 @@ internal static void ConvertArgsToBytes(
140142
SerDe.Write(destination, d);
141143
}
142144
break;
143-
145+
144146
case double[][] argDoubleArrayArray:
145147
SerDe.Write(destination, s_doubleArrayArrayTypeId);
146148
SerDe.Write(destination, argDoubleArrayArray.Length);

0 commit comments

Comments
 (0)