diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index d95512842..5c7bec3d2 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -18,7 +18,7 @@ variables:
backwardCompatibleTestsToFilterOut: "(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestDataFrameGroupedMapUdf)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestDataFrameVectorUdf)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestDestroy)&\
- (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestMultipleBroadcastWithoutEncryption)&\
+ (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestMultipleBroadcast)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestUnpersist)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.UdfTests.UdfComplexTypesTests.TestUdfWithArrayType)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.UdfTests.UdfComplexTypesTests.TestUdfWithArrayOfArrayType)&\
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs
index 511f5a122..e0443f04c 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs
@@ -34,9 +34,12 @@ public BroadcastTests(SparkFixture fixture)
/// <summary>
/// Test Broadcast support by using multiple broadcast variables in a UDF.
/// </summary>
- [Fact]
- public void TestMultipleBroadcastWithoutEncryption()
+ [Theory]
+ [InlineData("true")]
+ [InlineData("false")]
+ public void TestMultipleBroadcast(string isEncryptionEnabled)
{
+ _spark.SparkContext.GetConf().Set("spark.io.encryption.enabled", isEncryptionEnabled);
var obj1 = new TestBroadcastVariable(1, "first");
var obj2 = new TestBroadcastVariable(2, "second");
Broadcast bc1 = _spark.SparkContext.Broadcast(obj1);
@@ -49,15 +52,20 @@ public void TestMultipleBroadcastWithoutEncryption()
string[] actual = ToStringArray(_df.Select(udf(_df["_1"])));
Assert.Equal(expected, actual);
+ bc1.Destroy();
+ bc2.Destroy();
}
/// <summary>
/// Test Broadcast.Destroy() that destroys all data and metadata related to the broadcast
/// variable and makes it inaccessible from workers.
/// </summary>
- [Fact]
- public void TestDestroy()
+ [Theory]
+ [InlineData("true")]
+ [InlineData("false")]
+ public void TestDestroy(string isEncryptionEnabled)
{
+ _spark.SparkContext.GetConf().Set("spark.io.encryption.enabled", isEncryptionEnabled);
var obj1 = new TestBroadcastVariable(5, "destroy");
Broadcast bc1 = _spark.SparkContext.Broadcast(obj1);
@@ -96,9 +104,12 @@ public void TestDestroy()
/// <summary>
/// Test Broadcast.Unpersist() deletes cached copies of the broadcast on the executors. If
/// the broadcast is used after unpersist is called, it is re-sent to the executors.
/// </summary>
- [Fact]
- public void TestUnpersist()
+ [Theory]
+ [InlineData("true")]
+ [InlineData("false")]
+ public void TestUnpersist(string isEncryptionEnabled)
{
+ _spark.SparkContext.GetConf().Set("spark.io.encryption.enabled", isEncryptionEnabled);
var obj = new TestBroadcastVariable(1, "unpersist");
Broadcast bc = _spark.SparkContext.Broadcast(obj);
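
Note on the test changes above: converting [Fact] to [Theory] with [InlineData] runs the same test body once per data row, so each broadcast scenario is exercised with spark.io.encryption.enabled set to both "true" and "false". A minimal, self-contained sketch of that xUnit pattern (plain xunit, no Spark dependency; the class and test names here are illustrative only):

using Xunit;

public class EncryptionToggleSketchTests
{
    // The test body runs once per [InlineData] row, mirroring how
    // TestMultipleBroadcast, TestDestroy, and TestUnpersist are parameterized above.
    [Theory]
    [InlineData("true")]
    [InlineData("false")]
    public void RunsOncePerConfigValue(string isEncryptionEnabled)
    {
        // The driver parses this flag with bool.Parse, so only "true"/"false"
        // (case-insensitive) are meaningful values.
        bool enabled = bool.Parse(isEncryptionEnabled);
        Assert.Equal(isEncryptionEnabled == "true", enabled);
    }
}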
diff --git a/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs b/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs
index bf8f48ed8..e3bc16df6 100644
--- a/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs
+++ b/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs
@@ -3,9 +3,12 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Diagnostics;
using System.IO;
+using System.Net;
using System.Runtime.Serialization.Formatters.Binary;
using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Network;
namespace Microsoft.Spark.Worker.Processor
{
@@ -25,6 +28,7 @@ internal BroadcastVariableProcessor(Version version)
internal BroadcastVariables Process(Stream stream)
{
var broadcastVars = new BroadcastVariables();
+ ISocketWrapper socket = null;
if (_version >= new Version(Versions.V2_3_2))
{
@@ -37,7 +41,14 @@ internal BroadcastVariables Process(Stream stream)
{
broadcastVars.DecryptionServerPort = SerDe.ReadInt32(stream);
broadcastVars.Secret = SerDe.ReadString(stream);
- // TODO: Handle the authentication.
+ if (broadcastVars.Count > 0)
+ {
+ socket = SocketFactory.CreateSocket();
+ socket.Connect(
+ IPAddress.Loopback,
+ broadcastVars.DecryptionServerPort,
+ broadcastVars.Secret);
+ }
}
var formatter = new BinaryFormatter();
@@ -48,8 +59,15 @@ internal BroadcastVariables Process(Stream stream)
{
if (broadcastVars.DecryptionServerNeeded)
{
- throw new NotImplementedException(
- "broadcastDecryptionServer is not implemented.");
+ long readBid = SerDe.ReadInt64(socket.InputStream);
+ if (bid != readBid)
+ {
+ throw new Exception("The Broadcast Id received from the encryption " +
+ $"server {readBid} is different from the Broadcast Id received " +
+ $"from the payload {bid}.");
+ }
+ object value = formatter.Deserialize(socket.InputStream);
+ BroadcastRegistry.Add(bid, value);
}
else
{
@@ -66,6 +84,7 @@ internal BroadcastVariables Process(Stream stream)
BroadcastRegistry.Remove(bid);
}
}
+ socket?.Dispose();
return broadcastVars;
}
}
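
For reference, a hypothetical, self-contained sketch of the verify-then-deserialize step introduced above. It substitutes a MemoryStream for the decryption-server socket and a BinaryReader/BinaryWriter pair for SerDe, so the byte order here is illustrative only; the real worker relies on SerDe for JVM-compatible framing, and BinaryFormatter is assumed to be usable, as it is in the worker code above.

using System;
using System.IO;
using System.Runtime.Serialization.Formatters.Binary;
using System.Text;

static class BroadcastReadSketch
{
    // Reads a broadcast id, checks it against the expected id, then
    // deserializes the payload that follows it, as the worker does above.
    public static object ReadVerified(Stream input, long expectedBid)
    {
        using var reader = new BinaryReader(input, Encoding.UTF8, leaveOpen: true);
        long readBid = reader.ReadInt64();
        if (readBid != expectedBid)
        {
            throw new Exception($"Expected broadcast id {expectedBid} but read {readBid}.");
        }
        return new BinaryFormatter().Deserialize(input);
    }

    static void Main()
    {
        // Stand-in for the decryption server: id followed by the serialized value.
        using var ms = new MemoryStream();
        using (var writer = new BinaryWriter(ms, Encoding.UTF8, leaveOpen: true))
        {
            writer.Write(42L);
        }
        new BinaryFormatter().Serialize(ms, "broadcast me");
        ms.Position = 0;

        Console.WriteLine(ReadVerified(ms, expectedBid: 42)); // prints "broadcast me"
    }
}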
diff --git a/src/csharp/Microsoft.Spark/Broadcast.cs b/src/csharp/Microsoft.Spark/Broadcast.cs
index 99025b7c2..f0ea061fb 100644
--- a/src/csharp/Microsoft.Spark/Broadcast.cs
+++ b/src/csharp/Microsoft.Spark/Broadcast.cs
@@ -2,14 +2,15 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
+using System.Net;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using System.Threading;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Network;
using Microsoft.Spark.Services;
-
namespace Microsoft.Spark
{
/// <summary>
@@ -171,21 +172,49 @@ private JvmObjectReference CreateBroadcast_V2_3_2_AndAbove(
bool encryptionEnabled = bool.Parse(
sc.GetConf().Get("spark.io.encryption.enabled", "false"));
+ var _pythonBroadcast = (JvmObjectReference)javaSparkContext.Jvm.CallStaticJavaMethod(
+ "org.apache.spark.api.python.PythonRDD",
+ "setupBroadcast",
+ _path);
+
if (encryptionEnabled)
{
- throw new NotImplementedException("Broadcast encryption is not supported yet.");
+ var pair = (JvmObjectReference[])_pythonBroadcast.Invoke("setupEncryptionServer");
+
+ using (ISocketWrapper socket = SocketFactory.CreateSocket())
+ {
+ socket.Connect(
+ IPAddress.Loopback,
+ (int)pair[0].Invoke("intValue"), // port number
+ (string)pair[1].Invoke("toString")); // secret
+ WriteToStream(value, socket.OutputStream);
+ }
+ _pythonBroadcast.Invoke("waitTillDataReceived");
}
else
{
WriteToFile(value);
}
- var pythonBroadcast = (JvmObjectReference)javaSparkContext.Jvm.CallStaticJavaMethod(
- "org.apache.spark.api.python.PythonRDD",
- "setupBroadcast",
- _path);
+ return (JvmObjectReference)javaSparkContext.Invoke("broadcast", _pythonBroadcast);
+ }
- return (JvmObjectReference)javaSparkContext.Invoke("broadcast", pythonBroadcast);
+ /// TODO: This is not performant in the case of Broadcast encryption as it writes to stream
+ /// only after serializing the whole value, instead of serializing and writing in chunks
+ /// like Python.
+ /// <summary>
+ /// Function to write the broadcast value into the stream.
+ /// </summary>
+ /// <param name="value">Broadcast value to be written to the stream</param>
+ /// <param name="stream">Stream to write value to</param>
+ private void WriteToStream(object value, Stream stream)
+ {
+ using var ms = new MemoryStream();
+ Dump(value, ms);
+ SerDe.Write(stream, ms.Length);
+ ms.WriteTo(stream);
+ // -1 length indicates to the receiving end that we're done.
+ SerDe.Write(stream, -1);
}
/// <summary>
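
A hypothetical sketch of the chunked alternative that the TODO on WriteToStream alludes to: stream the serialized payload in fixed-size, length-prefixed chunks and finish with a -1 marker, instead of buffering the entire value in a MemoryStream first. The helper name, chunk size, and use of BinaryWriter for the length prefix are assumptions for illustration, not the Microsoft.Spark SerDe wire format.

using System.IO;

static class ChunkedWriteSketch
{
    // Copies 'payload' to 'destination' as length-prefixed chunks, ending with -1,
    // so the receiver can consume data incrementally rather than as one large frame.
    public static void WriteChunked(Stream payload, Stream destination, int chunkSize = 8192)
    {
        var lengthWriter = new BinaryWriter(destination);
        var buffer = new byte[chunkSize];
        int read;
        while ((read = payload.Read(buffer, 0, buffer.Length)) > 0)
        {
            lengthWriter.Write(read);            // chunk length prefix
            destination.Write(buffer, 0, read);  // chunk body
        }
        lengthWriter.Write(-1);                  // terminator: no more chunks
        lengthWriter.Flush();
    }
}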