@@ -18,23 +18,23 @@ import org.apache.spark.internal.Logging
18
18
* @param address The address of the Dotnet CallbackServer
19
19
* @param port The port of the Dotnet CallbackServer
20
20
*/
21
- class CallbackConnection (address : String , port : Int ) extends Logging {
21
+ class CallbackConnection (serDe : SerDe , address : String , port : Int ) extends Logging {
22
22
private [this ] val socket : Socket = new Socket (address, port)
23
23
private [this ] val inputStream : DataInputStream = new DataInputStream (socket.getInputStream)
24
24
private [this ] val outputStream : DataOutputStream = new DataOutputStream (socket.getOutputStream)
25
25
26
26
def send (
27
27
callbackId : Int ,
28
- writeBody : DataOutputStream => Unit ): Unit = {
28
+ writeBody : ( DataOutputStream , SerDe ) => Unit ): Unit = {
29
29
logInfo(s " Calling callback [callback id = $callbackId] ... " )
30
30
31
31
try {
32
- SerDe .writeInt(outputStream, CallbackFlags .CALLBACK )
33
- SerDe .writeInt(outputStream, callbackId)
32
+ serDe .writeInt(outputStream, CallbackFlags .CALLBACK )
33
+ serDe .writeInt(outputStream, callbackId)
34
34
35
35
val byteArrayOutputStream = new ByteArrayOutputStream ()
36
- writeBody(new DataOutputStream (byteArrayOutputStream))
37
- SerDe .writeInt(outputStream, byteArrayOutputStream.size)
36
+ writeBody(new DataOutputStream (byteArrayOutputStream), serDe )
37
+ serDe .writeInt(outputStream, byteArrayOutputStream.size)
38
38
byteArrayOutputStream.writeTo(outputStream);
39
39
} catch {
40
40
case e : Exception => {
@@ -44,7 +44,7 @@ class CallbackConnection(address: String, port: Int) extends Logging {
44
44
45
45
logInfo(s " Signaling END_OF_STREAM. " )
46
46
try {
47
- SerDe .writeInt(outputStream, CallbackFlags .END_OF_STREAM )
47
+ serDe .writeInt(outputStream, CallbackFlags .END_OF_STREAM )
48
48
outputStream.flush()
49
49
50
50
val endOfStreamResponse = readFlag(inputStream)
@@ -65,7 +65,7 @@ class CallbackConnection(address: String, port: Int) extends Logging {
65
65
66
66
def close (): Unit = {
67
67
try {
68
- SerDe .writeInt(outputStream, CallbackFlags .CLOSE )
68
+ serDe .writeInt(outputStream, CallbackFlags .CLOSE )
69
69
outputStream.flush()
70
70
} catch {
71
71
case e : Exception => logInfo(" Unable to send close to .NET callback server." , e)
@@ -95,9 +95,9 @@ class CallbackConnection(address: String, port: Int) extends Logging {
95
95
}
96
96
97
97
private def readFlag (inputStream : DataInputStream ): Int = {
98
- val callbackFlag = SerDe .readInt(inputStream)
98
+ val callbackFlag = serDe .readInt(inputStream)
99
99
if (callbackFlag == CallbackFlags .DOTNET_EXCEPTION_THROWN ) {
100
- val exceptionMessage = SerDe .readString(inputStream)
100
+ val exceptionMessage = serDe .readString(inputStream)
101
101
throw new DotnetException (exceptionMessage)
102
102
}
103
103
callbackFlag
@@ -109,4 +109,4 @@ class CallbackConnection(address: String, port: Int) extends Logging {
109
109
val DOTNET_EXCEPTION_THROWN : Int = - 3
110
110
val END_OF_STREAM : Int = - 4
111
111
}
112
- }
112
+ }
0 commit comments