diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index d95512842..5c7bec3d2 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -18,7 +18,7 @@ variables:
   backwardCompatibleTestsToFilterOut: "(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestDataFrameGroupedMapUdf)&\
     (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestDataFrameVectorUdf)&\
     (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestDestroy)&\
-    (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestMultipleBroadcastWithoutEncryption)&\
+    (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestMultipleBroadcast)&\
     (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.BroadcastTests.TestUnpersist)&\
     (FullyQualifiedName!=Microsoft.Spark.E2ETest.UdfTests.UdfComplexTypesTests.TestUdfWithArrayType)&\
     (FullyQualifiedName!=Microsoft.Spark.E2ETest.UdfTests.UdfComplexTypesTests.TestUdfWithArrayOfArrayType)&\
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs
index 511f5a122..e0443f04c 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/BroadcastTests.cs
@@ -34,9 +34,12 @@ public BroadcastTests(SparkFixture fixture)
         /// <summary>
         /// Test Broadcast support by using multiple broadcast variables in a UDF.
         /// </summary>
-        [Fact]
-        public void TestMultipleBroadcastWithoutEncryption()
+        [Theory]
+        [InlineData("true")]
+        [InlineData("false")]
+        public void TestMultipleBroadcast(string isEncryptionEnabled)
         {
+            _spark.SparkContext.GetConf().Set("spark.io.encryption.enabled", isEncryptionEnabled);
             var obj1 = new TestBroadcastVariable(1, "first");
             var obj2 = new TestBroadcastVariable(2, "second");
             Broadcast<TestBroadcastVariable> bc1 = _spark.SparkContext.Broadcast(obj1);
@@ -49,15 +52,20 @@ public void TestMultipleBroadcastWithoutEncryption()

             string[] actual = ToStringArray(_df.Select(udf(_df["_1"])));
             Assert.Equal(expected, actual);
+            bc1.Destroy();
+            bc2.Destroy();
         }

         /// <summary>
         /// Test Broadcast.Destroy() that destroys all data and metadata related to the broadcast
         /// variable and makes it inaccessible from workers.
         /// </summary>
-        [Fact]
-        public void TestDestroy()
+        [Theory]
+        [InlineData("true")]
+        [InlineData("false")]
+        public void TestDestroy(string isEncryptionEnabled)
         {
+            _spark.SparkContext.GetConf().Set("spark.io.encryption.enabled", isEncryptionEnabled);
             var obj1 = new TestBroadcastVariable(5, "destroy");
             Broadcast<TestBroadcastVariable> bc1 = _spark.SparkContext.Broadcast(obj1);

@@ -96,9 +104,12 @@ public void TestDestroy()
         /// Test Broadcast.Unpersist() deletes cached copies of the broadcast on the executors. If
         /// the broadcast is used after unpersist is called, it is re-sent to the executors.
         /// </summary>
-        [Fact]
-        public void TestUnpersist()
+        [Theory]
+        [InlineData("true")]
+        [InlineData("false")]
+        public void TestUnpersist(string isEncryptionEnabled)
         {
+            _spark.SparkContext.GetConf().Set("spark.io.encryption.enabled", isEncryptionEnabled);
             var obj = new TestBroadcastVariable(1, "unpersist");
             Broadcast<TestBroadcastVariable> bc = _spark.SparkContext.Broadcast(obj);

diff --git a/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs b/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs
index bf8f48ed8..e3bc16df6 100644
--- a/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs
+++ b/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs
@@ -3,9 +3,12 @@
 // See the LICENSE file in the project root for more information.

 using System;
+using System.Diagnostics;
 using System.IO;
+using System.Net;
 using System.Runtime.Serialization.Formatters.Binary;
 using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Network;

 namespace Microsoft.Spark.Worker.Processor
 {
@@ -25,6 +28,7 @@ internal BroadcastVariableProcessor(Version version)
         internal BroadcastVariables Process(Stream stream)
         {
             var broadcastVars = new BroadcastVariables();
+            ISocketWrapper socket = null;

             if (_version >= new Version(Versions.V2_3_2))
             {
@@ -37,7 +41,14 @@ internal BroadcastVariables Process(Stream stream)
             {
                 broadcastVars.DecryptionServerPort = SerDe.ReadInt32(stream);
                 broadcastVars.Secret = SerDe.ReadString(stream);
-                // TODO: Handle the authentication.
+                if (broadcastVars.Count > 0)
+                {
+                    socket = SocketFactory.CreateSocket();
+                    socket.Connect(
+                        IPAddress.Loopback,
+                        broadcastVars.DecryptionServerPort,
+                        broadcastVars.Secret);
+                }
             }

             var formatter = new BinaryFormatter();
@@ -48,8 +59,15 @@ internal BroadcastVariables Process(Stream stream)
                 {
                     if (broadcastVars.DecryptionServerNeeded)
                     {
-                        throw new NotImplementedException(
-                            "broadcastDecryptionServer is not implemented.");
+                        long readBid = SerDe.ReadInt64(socket.InputStream);
+                        if (bid != readBid)
+                        {
+                            throw new Exception("The Broadcast Id received from the encryption " +
+                                $"server {readBid} is different from the Broadcast Id received " +
+                                $"from the payload {bid}.");
+                        }
+                        object value = formatter.Deserialize(socket.InputStream);
+                        BroadcastRegistry.Add(bid, value);
                     }
                     else
                     {
@@ -66,6 +84,7 @@ internal BroadcastVariables Process(Stream stream)
                     BroadcastRegistry.Remove(bid);
                 }
             }
+            socket?.Dispose();
             return broadcastVars;
         }
     }
diff --git a/src/csharp/Microsoft.Spark/Broadcast.cs b/src/csharp/Microsoft.Spark/Broadcast.cs
index 99025b7c2..f0ea061fb 100644
--- a/src/csharp/Microsoft.Spark/Broadcast.cs
+++ b/src/csharp/Microsoft.Spark/Broadcast.cs
@@ -2,14 +2,15 @@
 using System.Collections.Concurrent;
 using System.Collections.Generic;
 using System.IO;
+using System.Net;
 using System.Runtime.Serialization;
 using System.Runtime.Serialization.Formatters.Binary;
 using System.Threading;
 using Microsoft.Spark.Interop;
 using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Network;
 using Microsoft.Spark.Services;

-
 namespace Microsoft.Spark
 {
     /// <summary>
@@ -171,21 +172,49 @@ private JvmObjectReference CreateBroadcast_V2_3_2_AndAbove(
             bool encryptionEnabled = bool.Parse(
                 sc.GetConf().Get("spark.io.encryption.enabled", "false"));

+            var _pythonBroadcast = (JvmObjectReference)javaSparkContext.Jvm.CallStaticJavaMethod(
+                "org.apache.spark.api.python.PythonRDD",
+                "setupBroadcast",
+                _path);
+
             if (encryptionEnabled)
             {
-                throw new NotImplementedException("Broadcast encryption is not supported yet.");
+                var pair =
+                    (JvmObjectReference[])_pythonBroadcast.Invoke("setupEncryptionServer");
+
+                using (ISocketWrapper socket = SocketFactory.CreateSocket())
+                {
+                    socket.Connect(
+                        IPAddress.Loopback,
+                        (int)pair[0].Invoke("intValue"), // port number
+                        (string)pair[1].Invoke("toString")); // secret
+                    WriteToStream(value, socket.OutputStream);
+                }
+                _pythonBroadcast.Invoke("waitTillDataReceived");
             }
             else
             {
                 WriteToFile(value);
             }

-            var pythonBroadcast = (JvmObjectReference)javaSparkContext.Jvm.CallStaticJavaMethod(
-                "org.apache.spark.api.python.PythonRDD",
-                "setupBroadcast",
-                _path);
+            return (JvmObjectReference)javaSparkContext.Invoke("broadcast", _pythonBroadcast);
+        }

-            return (JvmObjectReference)javaSparkContext.Invoke("broadcast", pythonBroadcast);
+        /// TODO: This is not performant in the case of Broadcast encryption as it writes to stream
+        /// only after serializing the whole value, instead of serializing and writing in chunks
+        /// like Python.
+        /// <summary>
+        /// Function to write the broadcast value into the stream.
+        /// </summary>
+        /// <param name="value">Broadcast value to be written to the stream</param>
+        /// <param name="stream">Stream to write value to</param>
+        private void WriteToStream(object value, Stream stream)
+        {
+            using var ms = new MemoryStream();
+            Dump(value, ms);
+            SerDe.Write(stream, ms.Length);
+            ms.WriteTo(stream);
+            // -1 length indicates to the receiving end that we're done.
+            SerDe.Write(stream, -1);
         }

         /// <summary>
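
Usage sketch (not part of the patch, illustrative only): with this change, broadcast variables work when Spark IO encryption is enabled; the serialized value is handed to the JVM-side PythonRDD encryption server instead of being written to a plain temp file. The sketch below assumes the Microsoft.Spark package and a local session, and mirrors the E2E tests above (single "_1" string column, a serializable broadcast value captured in a UDF). Note that spark.io.encryption.enabled must be in effect before the SparkContext is created, e.g. via the session builder as shown here or via spark-submit --conf; the class and literal values are hypothetical.

    using System;
    using Microsoft.Spark;
    using Microsoft.Spark.Sql;
    using static Microsoft.Spark.Sql.Functions;

    class BroadcastEncryptionExample
    {
        static void Main()
        {
            // Enable IO encryption before the context starts so the encrypted
            // broadcast path is exercised.
            SparkSession spark = SparkSession
                .Builder()
                .Config("spark.io.encryption.enabled", "true")
                .GetOrCreate();

            // Single string column named "_1", as in the BroadcastTests fixture.
            DataFrame df = spark.CreateDataFrame(new[] { "one", "two" });

            // Any serializable value can be broadcast; a string keeps this short.
            Broadcast<string> bc = spark.SparkContext.Broadcast("and broadcast");

            // The UDF reads the broadcast value on the executors.
            Func<Column, Column> udf = Udf<string, string>(s => $"{s} {bc.Value()}");
            df.Select(udf(df["_1"])).Show();

            bc.Destroy();
            spark.Stop();
        }
    }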