-
Notifications
You must be signed in to change notification settings - Fork 322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Included StringIndexer and StringIndexerModel along with related test… #804
Open
ramanathanv
wants to merge
16
commits into
dotnet:main
Choose a base branch
from
ramanathanv:StringIndexer
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
1f4bcdb
Included StringIndexer and StringIndexerModel along with related test…
65f43d0
Corrected issue in test case
f3b287c
Corrected issue in test case
89d1b98
Merge branch 'master' into StringIndexer
ramanathanv 0d04752
Merge branch 'master' into StringIndexer
imback82 ac68589
Merge branch 'master' into StringIndexer
ramanathanv f78d4ce
Corrected the test case
6cd1a7c
Changed FirstorDefault to Where
6ead393
Modified List datatype
fa1add4
Corrected the internal property names
24b3331
Merge branch 'master' into StringIndexer
ramanathanv 643789c
Changed List comparison
a5007c9
Merge branch 'StringIndexer' of https://github.com/ramanathanv/spark …
4cc337f
Reverted direct List Check
59ea4e5
Merge branch 'master' into StringIndexer
ramanathanv 641e1d3
Merge branch 'master' into StringIndexer
ramanathanv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
83 changes: 83 additions & 0 deletions
83
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StringIndexerModelTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Linq; | ||
using Microsoft.Spark.ML.Feature; | ||
using Microsoft.Spark.Sql; | ||
using Microsoft.Spark.Sql.Types; | ||
using Microsoft.Spark.UnitTest.TestUtils; | ||
using Xunit; | ||
|
||
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature | ||
{ | ||
/// <summary>
/// End-to-end tests that fit a <see cref="StringIndexer"/> on a small
/// <see cref="DataFrame"/> and exercise the resulting <see cref="StringIndexerModel"/>.
/// </summary>
[Collection("Spark E2E Tests")]
public class StringIndexerModelTests : FeatureBaseTests<StringIndexerModel>
{
    private readonly SparkSession _spark;

    public StringIndexerModelTests(SparkFixture fixture) : base(fixture)
    {
        _spark = fixture.Spark;
    }

    /// <summary>
    /// Fits a model on a six-row frame, transforms the same frame, compares the
    /// collected rows with the expected ones, checks the estimator's column
    /// settings and UID, and finally round-trips the model through Save/Load.
    /// </summary>
    [Fact]
    public void TestStringIndexerModel()
    {
        // Small helpers that keep the row literals below compact.
        GenericRow InputRow(int id, string category) =>
            new GenericRow(new object[] { id, category });
        Row ExpectedRow(string category, string index) =>
            new Row(new GenericRow(new object[] { category, index }));

        var inputRows = new List<GenericRow>
        {
            InputRow(0, "a"),
            InputRow(1, "b"),
            InputRow(2, "c"),
            InputRow(3, "a"),
            InputRow(4, "a"),
            InputRow(5, "c")
        };
        var inputSchema = new StructType(new List<StructField>
        {
            new StructField("id", new IntegerType()),
            new StructField("category", new StringType())
        });
        DataFrame input = _spark.CreateDataFrame(inputRows, inputSchema);

        string expectedUid = "theUid";
        StringIndexer stringIndexer = new StringIndexer(expectedUid);
        stringIndexer = stringIndexer.SetInputCol("category");
        stringIndexer = stringIndexer.SetOutputCol("categoryIndex");

        StringIndexerModel stringIndexerModel = stringIndexer.Fit(input);
        DataFrame transformed = stringIndexerModel.Transform(input);
        DataFrame projected = transformed.Select("category", "categoryIndex");
        List<Row> observed = projected.Collect().ToList();

        // NOTE(review): Spark's StringIndexer is documented to emit a numeric
        // (double) index column, while these expectations carry string values
        // and no schema - confirm this comparison can actually succeed.
        List<Row> expected = new List<Row>
        {
            ExpectedRow("a", "0"),
            ExpectedRow("b", "2"),
            ExpectedRow("c", "1"),
            ExpectedRow("a", "0"),
            ExpectedRow("a", "0"),
            ExpectedRow("c", "1")
        };

        Assert.Equal(expected, observed);
        Assert.Equal("category", stringIndexer.GetInputCol());
        Assert.Equal("categoryIndex", stringIndexer.GetOutputCol());
        Assert.Equal(expectedUid, stringIndexer.Uid());

        // Persist the fitted model into a scratch directory and make sure the
        // reloaded instance reports the same UID.
        using (var scratchDirectory = new TemporaryDirectory())
        {
            string modelPath = Path.Join(scratchDirectory.Path, "stringIndexerModel");
            stringIndexerModel.Save(modelPath);

            StringIndexerModel reloadedModel = StringIndexerModel.Load(modelPath);
            Assert.Equal(stringIndexerModel.Uid(), reloadedModel.Uid());
        }
    }
}
} |
52 changes: 52 additions & 0 deletions
52
src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StringIndexerTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Linq; | ||
using Microsoft.Spark.ML.Feature; | ||
using Microsoft.Spark.Sql; | ||
using Microsoft.Spark.Sql.Types; | ||
using Microsoft.Spark.UnitTest.TestUtils; | ||
using Xunit; | ||
|
||
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature | ||
{ | ||
/// <summary>
/// End-to-end tests for the <see cref="StringIndexer"/> estimator wrapper.
/// </summary>
[Collection("Spark E2E Tests")]
public class StringIndexerTests : FeatureBaseTests<StringIndexer>
{
    private readonly SparkSession _spark;

    public StringIndexerTests(SparkFixture fixture) : base(fixture)
    {
        _spark = fixture.Spark;
    }

    /// <summary>
    /// Configures a <see cref="StringIndexer"/> with an explicit UID and
    /// input/output columns, reads those settings back, then saves the
    /// estimator to a temporary directory and verifies the loaded copy keeps
    /// the same UID.
    /// </summary>
    [Fact]
    public void TestStringIndexer()
    {
        string expectedUid = "theUid";
        StringIndexer stringIndexer = new StringIndexer(expectedUid);
        stringIndexer = stringIndexer.SetInputCol("category");
        stringIndexer = stringIndexer.SetOutputCol("categoryIndex");

        Assert.Equal("category", stringIndexer.GetInputCol());
        Assert.Equal("categoryIndex", stringIndexer.GetOutputCol());
        Assert.Equal(expectedUid, stringIndexer.Uid());

        // Round-trip through Save/Load inside a scratch directory.
        using (var scratchDirectory = new TemporaryDirectory())
        {
            string estimatorPath = Path.Join(scratchDirectory.Path, "stringIndexer");
            stringIndexer.Save(estimatorPath);

            StringIndexer reloadedIndexer = StringIndexer.Load(estimatorPath);
            Assert.Equal(stringIndexer.Uid(), reloadedIndexer.Uid());
        }
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.Spark.Interop; | ||
using Microsoft.Spark.Interop.Ipc; | ||
using Microsoft.Spark.Sql; | ||
using Microsoft.Spark.Sql.Types; | ||
|
||
namespace Microsoft.Spark.ML.Feature | ||
{ | ||
/// <summary>
/// <see cref="StringIndexer"/> encodes a string column of labels to a column of label indices.
/// Every member forwards to the JVM class
/// <c>org.apache.spark.ml.feature.StringIndexer</c> through the wrapped
/// <see cref="JvmObjectReference"/>.
/// </summary>
public class StringIndexer : FeatureBase<StringIndexer>, IJvmObjectReferenceProvider
{
    private static readonly string s_stringIndexerClassName =
        "org.apache.spark.ml.feature.StringIndexer";

    /// <summary>
    /// Create a <see cref="StringIndexer"/> without any parameters.
    /// </summary>
    public StringIndexer() : base(s_stringIndexerClassName)
    {
    }

    /// <summary>
    /// Create a <see cref="StringIndexer"/> with a UID that is used to give the
    /// <see cref="StringIndexer"/> a unique ID.
    /// </summary>
    /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
    public StringIndexer(string uid) : base(s_stringIndexerClassName, uid)
    {
    }

    /// <summary>
    /// Wrap an existing JVM-side StringIndexer.
    /// </summary>
    /// <param name="jvmObject">Reference to the JVM object to wrap.</param>
    internal StringIndexer(JvmObjectReference jvmObject) : base(jvmObject)
    {
    }

    JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;

    /// <summary>
    /// Executes the <see cref="StringIndexer"/> and transforms the schema.
    /// </summary>
    /// <param name="value">The schema to be transformed.</param>
    /// <returns>
    /// New <see cref="StructType"/> object with the schema <see cref="StructType"/> transformed.
    /// </returns>
    public StructType TransformSchema(StructType value) =>
        new StructType(
            (JvmObjectReference)_jvmObject.Invoke(
                "transformSchema",
                DataType.FromJson(_jvmObject.Jvm, value.Json)));

    /// <summary>
    /// Executes the <see cref="StringIndexer"/> and fits a model to the input data.
    /// </summary>
    /// <param name="source">The <see cref="DataFrame"/> to fit the model to.</param>
    /// <returns>
    /// The <see cref="StringIndexerModel"/> wrapping the JVM model produced by the fit.
    /// </returns>
    public StringIndexerModel Fit(DataFrame source) =>
        new StringIndexerModel((JvmObjectReference)_jvmObject.Invoke("fit", source));

    /// <summary>
    /// Gets the handleInvalid setting. Spark documents the options as
    /// "skip", "error" and "keep".
    /// </summary>
    /// <returns>The current handleInvalid option.</returns>
    // The JVM member named "handleInvalid" is the Param object itself; the
    // value is exposed through the "get..." accessor, so that is what is called.
    public string GetHandleInvalid() => (string)_jvmObject.Invoke("getHandleInvalid");

    /// <summary>
    /// Sets the handleInvalid option on the <see cref="StringIndexer"/>.
    /// </summary>
    /// <param name="handleInvalid">The handleInvalid option to use.</param>
    /// <returns>
    /// <see cref="StringIndexer"/> wrapping the JVM object returned by the setter.
    /// </returns>
    public StringIndexer SetHandleInvalid(string handleInvalid) =>
        WrapAsStringIndexer(_jvmObject.Invoke("setHandleInvalid", handleInvalid));

    /// <summary>
    /// Gets the name of the input column.
    /// </summary>
    /// <returns>The input column name.</returns>
    public string GetInputCol() => (string)_jvmObject.Invoke("getInputCol");

    /// <summary>
    /// Sets the input column on the <see cref="StringIndexer"/>.
    /// </summary>
    /// <param name="inputCol">The input column name.</param>
    /// <returns>
    /// <see cref="StringIndexer"/> wrapping the JVM object returned by the setter.
    /// </returns>
    public StringIndexer SetInputCol(string inputCol) =>
        WrapAsStringIndexer(_jvmObject.Invoke("setInputCol", inputCol));

    /// <summary>
    /// Gets the names of the input columns.
    /// </summary>
    /// <returns>The input column names.</returns>
    public string[] GetInputCols() => (string[])_jvmObject.Invoke("getInputCols");

    /// <summary>
    /// Sets the input columns on the <see cref="StringIndexer"/>.
    /// </summary>
    /// <param name="inputCols">The input column names.</param>
    /// <returns>
    /// <see cref="StringIndexer"/> wrapping the JVM object returned by the setter.
    /// </returns>
    // Must target the plural setter: "setInputCol" only accepts a single name.
    public StringIndexer SetInputCols(string[] inputCols) =>
        WrapAsStringIndexer(_jvmObject.Invoke("setInputCols", inputCols));

    /// <summary>
    /// Gets the name of the output column.
    /// </summary>
    /// <returns>The output column name.</returns>
    public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol");

    /// <summary>
    /// Sets the output column on the <see cref="StringIndexer"/>.
    /// </summary>
    /// <param name="outputCol">The output column name.</param>
    /// <returns>
    /// <see cref="StringIndexer"/> wrapping the JVM object returned by the setter.
    /// </returns>
    public StringIndexer SetOutputCol(string outputCol) =>
        WrapAsStringIndexer(_jvmObject.Invoke("setOutputCol", outputCol));

    /// <summary>
    /// Gets the names of the output columns.
    /// </summary>
    /// <returns>The output column names.</returns>
    public string[] GetOutputCols() => (string[])_jvmObject.Invoke("getOutputCols");

    /// <summary>
    /// Sets the output columns on the <see cref="StringIndexer"/>.
    /// </summary>
    /// <param name="outputCols">The output column names.</param>
    /// <returns>
    /// <see cref="StringIndexer"/> wrapping the JVM object returned by the setter.
    /// </returns>
    // Must target the plural setter: "setOutputCol" only accepts a single name.
    public StringIndexer SetOutputCols(string[] outputCols) =>
        WrapAsStringIndexer(_jvmObject.Invoke("setOutputCols", outputCols));

    /// <summary>
    /// Gets the string order type. Spark documents the options as
    /// "frequencyDesc", "frequencyAsc", "alphabetDesc" and "alphabetAsc".
    /// </summary>
    /// <returns>The current string order type.</returns>
    public string GetStringOrderType() =>
        (string)_jvmObject.Invoke("getStringOrderType");

    /// <summary>
    /// Sets the string order type on the <see cref="StringIndexer"/>.
    /// </summary>
    /// <param name="stringOrderType">The string order type to use.</param>
    /// <returns>
    /// <see cref="StringIndexer"/> wrapping the JVM object returned by the setter.
    /// </returns>
    public StringIndexer SetStringOrderType(string stringOrderType) =>
        WrapAsStringIndexer(_jvmObject.Invoke("setStringOrderType", stringOrderType));

    /// <summary>
    /// Loads the <see cref="StringIndexer"/> that was previously saved using Save.
    /// </summary>
    /// <param name="path">The path the previous <see cref="StringIndexer"/> was saved to.</param>
    /// <returns>New <see cref="StringIndexer"/> object, loaded from path.</returns>
    public static StringIndexer Load(string path) =>
        WrapAsStringIndexer(
            SparkEnvironment.JvmBridge.CallStaticJavaMethod(
                s_stringIndexerClassName,
                "load",
                path));

    /// <summary>
    /// Casts a bridge result to <see cref="JvmObjectReference"/> and wraps it
    /// in a new <see cref="StringIndexer"/>.
    /// </summary>
    /// <param name="obj">Object returned by a JVM call.</param>
    /// <returns>A <see cref="StringIndexer"/> around the referenced JVM object.</returns>
    private static StringIndexer WrapAsStringIndexer(object obj) =>
        new StringIndexer((JvmObjectReference)obj);
}
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add the `param` and `returns` descriptions?