Skip to content

Commit 510bcd8

Browse files
Support encryption for broadcast variables (#489)
1 parent 4fe7405 commit 510bcd8

File tree

4 files changed

+76
-17
lines changed

4 files changed

+76
-17
lines changed

azure-pipelines.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ variables:
1818
backwardCompatibleTestsToFilterOut: "(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestDataFrameGroupedMapUdf)&\
1919
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestDataFrameVectorUdf)&\
2020
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestDestroy)&\
21-
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestMultipleBroadcastWithoutEncryption)&\
21+
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestMultipleBroadcast)&\
2222
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestUnpersist)&\
2323
(FullyQualifiedName!=Microsoft.Spark.E2ETest.UdfTests.UdfComplexTypesTests.TestUdfWithArrayType)&\
2424
(FullyQualifiedName!=Microsoft.Spark.E2ETest.UdfTests.UdfComplexTypesTests.TestUdfWithArrayOfArrayType)&\

src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs

+17-6
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@ public BroadcastTests(SparkFixture fixture)
3434
/// <summary>
3535
/// Test Broadcast support by using multiple broadcast variables in a UDF.
3636
/// </summary>
37-
[Fact]
38-
public void TestMultipleBroadcastWithoutEncryption()
37+
[Theory]
38+
[InlineData("true")]
39+
[InlineData("false")]
40+
public void TestMultipleBroadcast(string isEncryptionEnabled)
3941
{
42+
_spark.SparkContext.GetConf().Set("spark.io.encryption.enabled", isEncryptionEnabled);
4043
var obj1 = new TestBroadcastVariable(1, "first");
4144
var obj2 = new TestBroadcastVariable(2, "second");
4245
Broadcast<TestBroadcastVariable> bc1 = _spark.SparkContext.Broadcast(obj1);
@@ -49,15 +52,20 @@ public void TestMultipleBroadcastWithoutEncryption()
4952

5053
string[] actual = ToStringArray(_df.Select(udf(_df["_1"])));
5154
Assert.Equal(expected, actual);
55+
bc1.Destroy();
56+
bc2.Destroy();
5257
}
5358

5459
/// <summary>
5560
/// Test Broadcast.Destroy() that destroys all data and metadata related to the broadcast
5661
/// variable and makes it inaccessible from workers.
5762
/// </summary>
58-
[Fact]
59-
public void TestDestroy()
63+
[Theory]
64+
[InlineData("true")]
65+
[InlineData("false")]
66+
public void TestDestroy(string isEncryptionEnabled)
6067
{
68+
_spark.SparkContext.GetConf().Set("spark.io.encryption.enabled", isEncryptionEnabled);
6169
var obj1 = new TestBroadcastVariable(5, "destroy");
6270
Broadcast<TestBroadcastVariable> bc1 = _spark.SparkContext.Broadcast(obj1);
6371

@@ -96,9 +104,12 @@ public void TestDestroy()
96104
/// Test Broadcast.Unpersist() deletes cached copies of the broadcast on the executors. If
97105
/// the broadcast is used after unpersist is called, it is re-sent to the executors.
98106
/// </summary>
99-
[Fact]
100-
public void TestUnpersist()
107+
[Theory]
108+
[InlineData("true")]
109+
[InlineData("false")]
110+
public void TestUnpersist(string isEncryptionEnabled)
101111
{
112+
_spark.SparkContext.GetConf().Set("spark.io.encryption.enabled", isEncryptionEnabled);
102113
var obj = new TestBroadcastVariable(1, "unpersist");
103114
Broadcast<TestBroadcastVariable> bc = _spark.SparkContext.Broadcast(obj);
104115

src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs

+22-3
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Diagnostics;
67
using System.IO;
8+
using System.Net;
79
using System.Runtime.Serialization.Formatters.Binary;
810
using Microsoft.Spark.Interop.Ipc;
11+
using Microsoft.Spark.Network;
912

1013
namespace Microsoft.Spark.Worker.Processor
1114
{
@@ -25,6 +28,7 @@ internal BroadcastVariableProcessor(Version version)
2528
internal BroadcastVariables Process(Stream stream)
2629
{
2730
var broadcastVars = new BroadcastVariables();
31+
ISocketWrapper socket = null;
2832

2933
if (_version >= new Version(Versions.V2_3_2))
3034
{
@@ -37,7 +41,14 @@ internal BroadcastVariables Process(Stream stream)
3741
{
3842
broadcastVars.DecryptionServerPort = SerDe.ReadInt32(stream);
3943
broadcastVars.Secret = SerDe.ReadString(stream);
40-
// TODO: Handle the authentication.
44+
if (broadcastVars.Count > 0)
45+
{
46+
socket = SocketFactory.CreateSocket();
47+
socket.Connect(
48+
IPAddress.Loopback,
49+
broadcastVars.DecryptionServerPort,
50+
broadcastVars.Secret);
51+
}
4152
}
4253

4354
var formatter = new BinaryFormatter();
@@ -48,8 +59,15 @@ internal BroadcastVariables Process(Stream stream)
4859
{
4960
if (broadcastVars.DecryptionServerNeeded)
5061
{
51-
throw new NotImplementedException(
52-
"broadcastDecryptionServer is not implemented.");
62+
long readBid = SerDe.ReadInt64(socket.InputStream);
63+
if (bid != readBid)
64+
{
65+
throw new Exception("The Broadcast Id received from the encryption " +
66+
$"server {readBid} is different from the Broadcast Id received " +
67+
$"from the payload {bid}.");
68+
}
69+
object value = formatter.Deserialize(socket.InputStream);
70+
BroadcastRegistry.Add(bid, value);
5371
}
5472
else
5573
{
@@ -66,6 +84,7 @@ internal BroadcastVariables Process(Stream stream)
6684
BroadcastRegistry.Remove(bid);
6785
}
6886
}
87+
socket?.Dispose();
6988
return broadcastVars;
7089
}
7190
}

src/csharp/Microsoft.Spark/Broadcast.cs

+36-7
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
using System.Collections.Concurrent;
33
using System.Collections.Generic;
44
using System.IO;
5+
using System.Net;
56
using System.Runtime.Serialization;
67
using System.Runtime.Serialization.Formatters.Binary;
78
using System.Threading;
89
using Microsoft.Spark.Interop;
910
using Microsoft.Spark.Interop.Ipc;
11+
using Microsoft.Spark.Network;
1012
using Microsoft.Spark.Services;
1113

12-
1314
namespace Microsoft.Spark
1415
{
1516
/// <summary>
@@ -171,21 +172,49 @@ private JvmObjectReference CreateBroadcast_V2_3_2_AndAbove(
171172
bool encryptionEnabled = bool.Parse(
172173
sc.GetConf().Get("spark.io.encryption.enabled", "false"));
173174

175+
var _pythonBroadcast = (JvmObjectReference)javaSparkContext.Jvm.CallStaticJavaMethod(
176+
"org.apache.spark.api.python.PythonRDD",
177+
"setupBroadcast",
178+
_path);
179+
174180
if (encryptionEnabled)
175181
{
176-
throw new NotImplementedException("Broadcast encryption is not supported yet.");
182+
var pair = (JvmObjectReference[])_pythonBroadcast.Invoke("setupEncryptionServer");
183+
184+
using (ISocketWrapper socket = SocketFactory.CreateSocket())
185+
{
186+
socket.Connect(
187+
IPAddress.Loopback,
188+
(int)pair[0].Invoke("intValue"), // port number
189+
(string)pair[1].Invoke("toString")); // secret
190+
WriteToStream(value, socket.OutputStream);
191+
}
192+
_pythonBroadcast.Invoke("waitTillDataReceived");
177193
}
178194
else
179195
{
180196
WriteToFile(value);
181197
}
182198

183-
var pythonBroadcast = (JvmObjectReference)javaSparkContext.Jvm.CallStaticJavaMethod(
184-
"org.apache.spark.api.python.PythonRDD",
185-
"setupBroadcast",
186-
_path);
199+
return (JvmObjectReference)javaSparkContext.Invoke("broadcast", _pythonBroadcast);
200+
}
187201

188-
return (JvmObjectReference)javaSparkContext.Invoke("broadcast", pythonBroadcast);
202+
/// TODO: This is not performant in the case of Broadcast encryption as it writes to stream
203+
/// only after serializing the whole value, instead of serializing and writing in chunks
204+
/// like Python.
205+
/// <summary>
206+
/// Function to write the broadcast value into the stream.
207+
/// </summary>
208+
/// <param name="value">Broadcast value to be written to the stream</param>
209+
/// <param name="stream">Stream to write value to</param>
210+
private void WriteToStream(object value, Stream stream)
211+
{
212+
using var ms = new MemoryStream();
213+
Dump(value, ms);
214+
SerDe.Write(stream, ms.Length);
215+
ms.WriteTo(stream);
216+
// -1 length indicates to the receiving end that we're done.
217+
SerDe.Write(stream, -1);
189218
}
190219

191220
/// <summary>

0 commit comments

Comments
 (0)