Skip to content

Commit 1ab206d

Browse files
Expose new SparkSession, DataFrame, and DataFrameStatFunctions APIs introduced in Spark 3.0 (#647)
1 parent 46b0d77 commit 1ab206d

File tree

5 files changed

+146
-10
lines changed

5 files changed

+146
-10
lines changed

src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameFunctionsTests.cs

+17
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.Collections.Generic;
6+
using Microsoft.Spark.E2ETest.Utils;
67
using Microsoft.Spark.Sql;
8+
using static Microsoft.Spark.Sql.Functions;
79
using Xunit;
810

911
namespace Microsoft.Spark.E2ETest.IpcTests
@@ -91,5 +93,20 @@ public void TestDataFrameStatFunctionSignatures()
9193

9294
df = stat.SampleBy("age", new Dictionary<int, double> { { 1, 0.5 } }, 100);
9395
}
96+
97+
/// <summary>
/// Test signatures for APIs introduced in Spark 3.0.*.
/// </summary>
[SkipIfSparkVersionIsLessThan(Versions.V3_0_0)]
public void TestSignaturesV3_0_X()
{
    DataFrameStatFunctions statFunctions = _df.Stat();
    Column ageColumn = Column("age");

    var fractions = new Dictionary<int, double> { { 1, 0.5 } };
    Assert.IsType<DataFrame>(statFunctions.SampleBy(ageColumn, fractions, 100));
}
94111
}
95112
}

src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs

+12
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,18 @@ public void TestSignaturesV3_X_X()
712712
IEnumerable<Row> actual = df.ToLocalIterator(true).ToArray();
713713
IEnumerable<Row> expected = data.Select(r => new Row(r.Values, schema));
714714
Assert.Equal(expected, actual);
715+
716+
Assert.IsType<DataFrame>(df.Observe("metrics", Count("Name").As("CountNames")));
717+
718+
Assert.IsType<Row[]>(_df.Tail(1).ToArray());
719+
720+
_df.PrintSchema(1);
721+
722+
_df.Explain("simple");
723+
_df.Explain("extended");
724+
_df.Explain("codegen");
725+
_df.Explain("cost");
726+
_df.Explain("formatted");
715727
}
716728
}
717729
}

src/csharp/Microsoft.Spark/Sql/DataFrame.cs

+77-10
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ public void PrintSchema() =>
7070
Console.WriteLine(
7171
(string)((JvmObjectReference)_jvmObject.Invoke("schema")).Invoke("treeString"));
7272

73+
/// <summary>
/// Prints the schema up to the given level to the console in a nice tree format.
/// </summary>
/// <param name="level">Depth up to which the schema tree is printed</param>
[Since(Versions.V3_0_0)]
public void PrintSchema(int level)
{
    // Ask the JVM-side Dataset for its schema and render it as a depth-limited tree.
    var schema = (JvmObjectReference)_jvmObject.Invoke("schema");
    Console.WriteLine((string)schema.Invoke("treeString", level));
}
82+
7383
/// <summary>
7484
/// Prints the plans (logical and physical) to the console for debugging purposes.
7585
/// </summary>
@@ -80,6 +90,30 @@ public void Explain(bool extended = false)
8090
Console.WriteLine((string)execution.Invoke(extended ? "toString" : "simpleString"));
8191
}
8292

93+
/// <summary>
/// Prints the plans (logical and physical) with a format specified by a given explain
/// mode.
/// </summary>
/// <param name="mode">Specifies the expected output format of plans.
/// 1. `simple`: Print only a physical plan.
/// 2. `extended`: Print both logical and physical plans.
/// 3. `codegen`: Print a physical plan and generated codes if they are available.
/// 4. `cost`: Print a logical plan and statistics if they are available.
/// 5. `formatted`: Split explain output into two sections: a physical plan outline and
/// node details.
/// </param>
[Since(Versions.V3_0_0)]
public void Explain(string mode)
{
    // Resolve the mode string to a JVM-side ExplainMode, then ask the query
    // execution for the matching explain output.
    var jvmExplainMode = (JvmObjectReference)_jvmObject.Jvm.CallStaticJavaMethod(
        "org.apache.spark.sql.execution.ExplainMode",
        "fromString",
        mode);
    var queryExecution = (JvmObjectReference)_jvmObject.Invoke("queryExecution");
    Console.WriteLine((string)queryExecution.Invoke("explainString", jvmExplainMode));
}
116+
83117
/// <summary>
84118
/// Returns all column names and their data types as an IEnumerable of Tuples.
85119
/// </summary>
@@ -480,6 +514,27 @@ public RelationalGroupedDataset Cube(string column, params string[] columns) =>
480514
public DataFrame Agg(Column expr, params Column[] exprs) =>
481515
WrapAsDataFrame(_jvmObject.Invoke("agg", expr, exprs));
482516

517+
/// <summary>
/// Define (named) metrics to observe on the Dataset. This method returns an 'observed'
/// DataFrame that returns the same result as the input, with the following guarantees:
///
/// 1. It will compute the defined aggregates (metrics) on all the data that is flowing
/// through the Dataset at that point.
/// 2. It will report the value of the defined aggregate columns as soon as we reach a
/// completion point. A completion point is either the end of a query (batch mode) or
/// the end of a streaming epoch. The value of the aggregates only reflects the data
/// processed since the previous completion point.
///
/// Please note that continuous execution is currently not supported.
/// </summary>
/// <param name="name">Name of the observation to register the metrics under</param>
/// <param name="expr">Defined aggregate to observe</param>
/// <param name="exprs">Defined aggregates to observe</param>
/// <returns>DataFrame object</returns>
[Since(Versions.V3_0_0)]
public DataFrame Observe(string name, Column expr, params Column[] exprs) =>
    WrapAsDataFrame(_jvmObject.Invoke("observe", name, expr, exprs));
537+
483538
/// <summary>
484539
/// Returns a new `DataFrame` by taking the first `number` rows.
485540
/// </summary>
@@ -702,6 +757,17 @@ public DataFrame Summary(params string[] statistics) =>
702757
/// <returns>First `n` rows</returns>
703758
public IEnumerable<Row> Take(int n) => Head(n);
704759

760+
/// <summary>
/// Returns the last `n` rows in the `DataFrame`.
/// </summary>
/// <param name="n">Number of rows</param>
/// <returns>Last `n` rows</returns>
[Since(Versions.V3_0_0)]
public IEnumerable<Row> Tail(int n) =>
    // Expression-bodied for consistency with the neighboring Take(int n) => Head(n);
    // "tailToPython" streams the rows back over the socket-based collection path.
    GetRows("tailToPython", n);
770+
705771
/// <summary>
706772
/// Returns an array that contains all rows in this `DataFrame`.
707773
/// </summary>
@@ -929,16 +995,15 @@ public DataStreamWriter WriteStream() =>
929995
new DataStreamWriter((JvmObjectReference)_jvmObject.Invoke("writeStream"), this);
930996

931997
/// <summary>
932-
/// Returns row objects based on the function (either "toPythonIterator" or
933-
/// "collectToPython").
998+
/// Returns row objects based on the function (either "toPythonIterator",
999+
/// "collectToPython", or "tailToPython").
9341000
/// </summary>
935-
/// <param name="funcName">
936-
/// The name of the function to call, either "toPythonIterator" or "collectToPython".
937-
/// </param>
938-
/// <returns><see cref="Row"/> objects</returns>
939-
private IEnumerable<Row> GetRows(string funcName)
1001+
/// <param name="funcName">String name of function to call</param>
1002+
/// <param name="args">Arguments to the function</param>
1003+
/// <returns>IEnumerable of Rows from Spark</returns>
1004+
private IEnumerable<Row> GetRows(string funcName, params object[] args)
9401005
{
941-
(int port, string secret, _) = GetConnectionInfo(funcName);
1006+
(int port, string secret, _) = GetConnectionInfo(funcName, args);
9421007
using ISocketWrapper socket = SocketFactory.CreateSocket();
9431008
socket.Connect(IPAddress.Loopback, port, secret);
9441009
foreach (Row row in new RowCollector().Collect(socket))
@@ -952,9 +1017,11 @@ private IEnumerable<Row> GetRows(string funcName)
9521017
/// used for connecting with Spark to receive rows for this `DataFrame`.
9531018
/// </summary>
9541019
/// <returns>A tuple of port number, secret string, and JVM socket auth server.</returns>
955-
private (int, string, JvmObjectReference) GetConnectionInfo(string funcName)
1020+
private (int, string, JvmObjectReference) GetConnectionInfo(
1021+
string funcName,
1022+
params object[] args)
9561023
{
957-
object result = _jvmObject.Invoke(funcName);
1024+
object result = _jvmObject.Invoke(funcName, args);
9581025
Version version = SparkEnvironment.SparkVersion;
9591026
return (version.Major, version.Minor, version.Build) switch
9601027
{

src/csharp/Microsoft.Spark/Sql/DataFrameStatFunctions.cs

+16
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,22 @@ public DataFrame SampleBy<T>(
121121
long seed) =>
122122
WrapAsDataFrame(_jvmObject.Invoke("sampleBy", columnName, fractions, seed));
123123

124+
/// <summary>
/// Returns a stratified sample without replacement based on the fraction given
/// on each stratum.
/// </summary>
/// <typeparam name="T">Stratum type</typeparam>
/// <param name="column">Column that defines strata</param>
/// <param name="fractions">
/// Sampling fraction for each stratum. If a stratum is not specified, we treat
/// its fraction as zero.
/// </param>
/// <param name="seed">Random seed</param>
/// <returns>DataFrame object</returns>
[Since(Versions.V3_0_0)]
public DataFrame SampleBy<T>(Column column, IDictionary<T, double> fractions, long seed)
{
    object sampled = _jvmObject.Invoke("sampleBy", column, fractions, seed);
    return WrapAsDataFrame(sampled);
}
139+
124140
private DataFrame WrapAsDataFrame(object obj) => new DataFrame((JvmObjectReference)obj);
125141
}
126142
}

src/csharp/Microsoft.Spark/Sql/SparkSession.cs

+24
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,30 @@ public DataFrame CreateDataFrame(IEnumerable<Timestamp> data) =>
255255
public DataFrame Sql(string sqlText) =>
256256
new DataFrame((JvmObjectReference)_jvmObject.Invoke("sql", sqlText));
257257

258+
/// <summary>
/// Execute an arbitrary string command inside an external execution engine rather than
/// Spark. This could be useful when user wants to execute some commands out of Spark. For
/// example, executing custom DDL/DML command for JDBC, creating index for ElasticSearch,
/// creating cores for Solr and so on.
/// The command will be eagerly executed after this method is called and the returned
/// DataFrame will contain the output of the command (if any).
/// </summary>
/// <param name="runner">The class name of the runner that implements
/// `ExternalCommandRunner`</param>
/// <param name="command">The target command to be executed</param>
/// <param name="options">The options for the runner</param>
/// <returns>DataFrame object</returns>
[Since(Versions.V3_0_0)]
public DataFrame ExecuteCommand(
    string runner,
    string command,
    Dictionary<string, string> options) =>
    new DataFrame((JvmObjectReference)_jvmObject.Invoke(
        "executeCommand",
        runner,
        command,
        options));
281+
258282
/// <summary>
259283
/// Returns a DataFrameReader that can be used to read non-streaming data in
260284
/// as a DataFrame.

0 commit comments

Comments
 (0)