Skip to content

Commit d2822c3

Browse files
authored
Implement ML Feature: SQLTransformer (#781)
1 parent 1767c3e commit d2822c3

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using System.IO;
7+
using Microsoft.Spark.ML.Feature;
8+
using Microsoft.Spark.Sql;
9+
using Microsoft.Spark.Sql.Types;
10+
using Microsoft.Spark.UnitTest.TestUtils;
11+
using Xunit;
12+
13+
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
14+
{
15+
[Collection("Spark E2E Tests")]
public class SQLTransformerTests : FeatureBaseTests<SQLTransformer>
{
    private readonly SparkSession _spark;

    public SQLTransformerTests(SparkFixture fixture) : base(fixture)
    {
        _spark = fixture.Spark;
    }

    /// <summary>
    /// Builds a small three-column <see cref="DataFrame"/>, configures a
    /// <see cref="SQLTransformer"/> with a statement that derives two extra columns, and
    /// exercises SetStatement/GetStatement, Transform, TransformSchema, Save/Load and Uid.
    /// </summary>
    [Fact]
    public void TestSQLTransformer()
    {
        // Two rows of (id, v1, v2) used as the transformer input.
        var rows = new List<GenericRow>
        {
            new GenericRow(new object[] { 0, 1.0, 3.0 }),
            new GenericRow(new object[] { 2, 2.0, 5.0 })
        };
        var schema = new StructType(new List<StructField>
        {
            new StructField("id", new IntegerType()),
            new StructField("v1", new DoubleType()),
            new StructField("v2", new DoubleType())
        });
        DataFrame inputFrame = _spark.CreateDataFrame(rows, schema);

        const string expectedUid = "theUid";
        const string statement =
            "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__";

        SQLTransformer transformer =
            new SQLTransformer(expectedUid).SetStatement(statement);

        string roundTrippedStatement = transformer.GetStatement();

        DataFrame transformed = transformer.Transform(inputFrame);
        StructType transformedSchema = transformer.TransformSchema(inputFrame.Schema());

        // Both the transformed data and the transformed schema must expose the columns
        // named by the statement's AS clauses.
        Assert.Contains(transformed.Schema().Fields, field => field.Name == "v3");
        Assert.Contains(transformed.Schema().Fields, field => field.Name == "v4");
        Assert.Contains(transformedSchema.Fields, field => field.Name == "v3");
        Assert.Contains(transformedSchema.Fields, field => field.Name == "v4");
        Assert.Equal(statement, roundTrippedStatement);

        // Round-trip through Save/Load and check that the uid is preserved.
        using (var tempDir = new TemporaryDirectory())
        {
            string modelPath = Path.Join(tempDir.Path, "SQLTransformer");
            transformer.Save(modelPath);

            SQLTransformer reloaded = SQLTransformer.Load(modelPath);
            Assert.Equal(transformer.Uid(), reloaded.Uid());
        }

        Assert.Equal(expectedUid, transformer.Uid());
    }
}
73+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Spark.Interop;
6+
using Microsoft.Spark.Interop.Ipc;
7+
using Microsoft.Spark.Sql;
8+
using Microsoft.Spark.Sql.Types;
9+
10+
namespace Microsoft.Spark.ML.Feature
11+
{
12+
/// <summary>
/// <see cref="SQLTransformer"/> wraps the JVM class
/// <c>org.apache.spark.ml.feature.SQLTransformer</c>, a feature transformer whose
/// transformation is defined by a SQL statement.
/// </summary>
public class SQLTransformer : FeatureBase<SQLTransformer>, IJvmObjectReferenceProvider
{
    private static readonly string s_sqlTransformerClassName =
        "org.apache.spark.ml.feature.SQLTransformer";

    /// <summary>
    /// Creates a <see cref="SQLTransformer"/> without any parameters.
    /// </summary>
    public SQLTransformer()
        : base(s_sqlTransformerClassName)
    {
    }

    /// <summary>
    /// Creates a <see cref="SQLTransformer"/> with the given UID.
    /// </summary>
    /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
    public SQLTransformer(string uid)
        : base(s_sqlTransformerClassName, uid)
    {
    }

    /// <summary>
    /// Wraps an existing JVM object reference as a <see cref="SQLTransformer"/>.
    /// </summary>
    /// <param name="jvmObject">Reference to the JVM-side object to wrap.</param>
    internal SQLTransformer(JvmObjectReference jvmObject)
        : base(jvmObject)
    {
    }

    /// <summary>
    /// The reference to the underlying JVM object.
    /// </summary>
    JvmObjectReference IJvmObjectReferenceProvider.Reference
    {
        get { return _jvmObject; }
    }

    /// <summary>
    /// Invokes <c>transform</c> on the underlying JVM object with the given
    /// <see cref="DataFrame"/> and wraps the result.
    /// </summary>
    /// <param name="source">The <see cref="DataFrame"/> to transform.</param>
    /// <returns>
    /// A new <see cref="DataFrame"/> wrapping the reference returned by the JVM call.
    /// </returns>
    public DataFrame Transform(DataFrame source)
    {
        object transformed = _jvmObject.Invoke("transform", source);
        return new DataFrame((JvmObjectReference)transformed);
    }

    /// <summary>
    /// Invokes <c>transformSchema</c> on the underlying JVM object and wraps the result.
    /// </summary>
    /// <param name="value">The schema to be transformed.</param>
    /// <returns>
    /// A new <see cref="StructType"/> built from the reference returned by the JVM call.
    /// </returns>
    public StructType TransformSchema(StructType value)
    {
        // The schema crosses to the JVM side via its JSON form (DataType.FromJson).
        object transformedSchema = _jvmObject.Invoke(
            "transformSchema",
            DataType.FromJson(_jvmObject.Jvm, value.Json));

        return new StructType((JvmObjectReference)transformedSchema);
    }

    /// <summary>
    /// Gets the statement currently set on this <see cref="SQLTransformer"/>.
    /// </summary>
    /// <returns>The statement, as returned by the JVM-side <c>getStatement</c>.</returns>
    public string GetStatement()
    {
        return (string)_jvmObject.Invoke("getStatement");
    }

    /// <summary>
    /// Sets the statement on this <see cref="SQLTransformer"/>.
    /// </summary>
    /// <param name="statement">The SQL statement.</param>
    /// <returns>
    /// A <see cref="SQLTransformer"/> wrapping the reference returned by the JVM-side
    /// <c>setStatement</c> call.
    /// </returns>
    public SQLTransformer SetStatement(string statement)
    {
        object updated = _jvmObject.Invoke("setStatement", statement);
        return WrapAsSQLTransformer((JvmObjectReference)updated);
    }

    /// <summary>
    /// Loads a <see cref="SQLTransformer"/> from the given path by calling the static
    /// JVM-side <c>load</c> method.
    /// </summary>
    /// <param name="path">The path the <see cref="SQLTransformer"/> was saved to.</param>
    /// <returns>A new <see cref="SQLTransformer"/> object, loaded from path.</returns>
    public static SQLTransformer Load(string path)
    {
        object loaded = SparkEnvironment.JvmBridge.CallStaticJavaMethod(
            s_sqlTransformerClassName,
            "load",
            path);

        return WrapAsSQLTransformer(loaded);
    }

    // Casts a bridge result to a JVM object reference and wraps it in a new instance.
    private static SQLTransformer WrapAsSQLTransformer(object obj)
    {
        return new SQLTransformer((JvmObjectReference)obj);
    }
}
97+
}

0 commit comments

Comments
 (0)