Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Included StringIndexer and StringIndexerModel along with related test… #804

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
/// <summary>
/// End-to-end tests for <see cref="StringIndexerModel"/>, using the
/// <see cref="SparkSession"/> exposed by the shared <see cref="SparkFixture"/>.
/// </summary>
[Collection("Spark E2E Tests")]
public class StringIndexerModelTests : FeatureBaseTests<StringIndexerModel>
{
    private readonly SparkSession _spark;

    public StringIndexerModelTests(SparkFixture fixture) : base(fixture)
    {
        _spark = fixture.Spark;
    }

    /// <summary>
    /// Create a <see cref="DataFrame"/>, fit a <see cref="StringIndexerModel"/> on it, verify
    /// the indices produced by Transform, and round-trip the model through Save/Load.
    /// </summary>
    [Fact]
    public void TestStringIndexerModel()
    {
        DataFrame input = _spark.CreateDataFrame(
            new List<GenericRow>
            {
                new GenericRow(new object[] {0, "a"}),
                new GenericRow(new object[] {1, "b"}),
                new GenericRow(new object[] {2, "c"}),
                new GenericRow(new object[] {3, "a"}),
                new GenericRow(new object[] {4, "a"}),
                new GenericRow(new object[] {5, "c"})
            },
            new StructType(new List<StructField>
            {
                new StructField("id", new IntegerType()),
                new StructField("category", new StringType())
            }));

        string expectedUid = "theUid";
        StringIndexer stringIndexer = new StringIndexer(expectedUid)
            .SetInputCol("category")
            .SetOutputCol("categoryIndex");

        StringIndexerModel stringIndexerModel = stringIndexer.Fit(input);
        DataFrame transformedDF = stringIndexerModel.Transform(input);
        List<Row> observed = transformedDF
            .Select("category", new string[] { "categoryIndex" })
            .Collect()
            .ToList();

        // Spark's StringIndexer emits indices as doubles (not strings) and, with its
        // default ordering, gives the most frequent label index 0:
        // "a" (3 rows) -> 0.0, "c" (2 rows) -> 1.0, "b" (1 row) -> 2.0.
        string[] expectedCategories = { "a", "b", "c", "a", "a", "c" };
        double[] expectedIndices = { 0.0, 2.0, 1.0, 0.0, 0.0, 1.0 };

        // Compare column values rather than Row instances so the check does not depend
        // on constructing Row objects (and their schema) by hand.
        Assert.Equal(expectedCategories, observed.Select(row => row.GetAs<string>(0)));
        Assert.Equal(expectedIndices, observed.Select(row => row.GetAs<double>(1)));

        Assert.Equal("category", stringIndexer.GetInputCol());
        Assert.Equal("categoryIndex", stringIndexer.GetOutputCol());
        Assert.Equal(expectedUid, stringIndexer.Uid());

        using (var tempDirectory = new TemporaryDirectory())
        {
            string savePath = Path.Join(tempDirectory.Path, "stringIndexerModel");
            stringIndexerModel.Save(savePath);

            StringIndexerModel loadedModel = StringIndexerModel.Load(savePath);
            Assert.Equal(stringIndexerModel.Uid(), loadedModel.Uid());
        }
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
/// <summary>
/// End-to-end tests for the <see cref="StringIndexer"/> wrapper.
/// </summary>
[Collection("Spark E2E Tests")]
public class StringIndexerTests : FeatureBaseTests<StringIndexer>
{
    private readonly SparkSession _spark;

    public StringIndexerTests(SparkFixture fixture)
        : base(fixture)
    {
        _spark = fixture.Spark;
    }

    /// <summary>
    /// Builds a <see cref="StringIndexer"/> with an explicit UID, checks that the column
    /// settings and UID read back as configured, then saves it to a temporary directory
    /// and verifies the reloaded instance reports the same UID.
    /// </summary>
    [Fact]
    public void TestStringIndexer()
    {
        const string uid = "theUid";
        const string inputColumn = "category";
        const string outputColumn = "categoryIndex";

        StringIndexer withInput = new StringIndexer(uid).SetInputCol(inputColumn);
        StringIndexer indexer = withInput.SetOutputCol(outputColumn);

        Assert.Equal(inputColumn, indexer.GetInputCol());
        Assert.Equal(outputColumn, indexer.GetOutputCol());
        Assert.Equal(uid, indexer.Uid());

        using (TemporaryDirectory scratch = new TemporaryDirectory())
        {
            string target = Path.Join(scratch.Path, "stringIndexer");
            indexer.Save(target);

            StringIndexer reloaded = StringIndexer.Load(target);
            Assert.Equal(indexer.Uid(), reloaded.Uid());
        }
    }
}
}
174 changes: 174 additions & 0 deletions src/csharp/Microsoft.Spark/ML/Feature/StringIndexer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;

namespace Microsoft.Spark.ML.Feature
{
/// <summary>
/// <see cref="StringIndexer"/> encodes a string column of labels to a column of label indices.
/// </summary>
public class StringIndexer : FeatureBase<StringIndexer>, IJvmObjectReferenceProvider
{
/// <summary>
/// Class name handed to the base constructor and used as the target of the static
/// "load" call.
/// </summary>
private static readonly string s_StringIndexerClassName =
    "org.apache.spark.ml.feature.StringIndexer";

/// <summary>
/// Initializes a new <see cref="StringIndexer"/> with no parameters supplied.
/// </summary>
public StringIndexer()
    : base(s_StringIndexerClassName)
{
}

/// <summary>
/// Initializes a new <see cref="StringIndexer"/> identified by the supplied UID.
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public StringIndexer(string uid)
    : base(s_StringIndexerClassName, uid)
{
}

/// <summary>
/// Wraps an existing JVM object reference in a <see cref="StringIndexer"/>.
/// </summary>
/// <param name="jvmObject">Reference forwarded to the base class.</param>
internal StringIndexer(JvmObjectReference jvmObject)
    : base(jvmObject)
{
}

/// <summary>
/// Exposes the underlying JVM object reference held by this instance.
/// </summary>
JvmObjectReference IJvmObjectReferenceProvider.Reference
{
    get { return _jvmObject; }
}

/// <summary>
/// Runs the <see cref="StringIndexer"/>'s "transformSchema" on the given schema.
/// </summary>
/// <param name="value">The schema to be transformed</param>
/// <returns>
/// A new <see cref="StructType"/> wrapping the schema returned by the JVM call.
/// </returns>
public StructType TransformSchema(StructType value)
{
    // The schema crosses to the JVM side as a data type rebuilt from its JSON form.
    var jvmSchema = DataType.FromJson(_jvmObject.Jvm, value.Json);
    var transformed = (JvmObjectReference)_jvmObject.Invoke("transformSchema", jvmSchema);

    return new StructType(transformed);
}

/// <summary>
/// Executes the <see cref="StringIndexer"/> and fits a model to the input data.
/// </summary>
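/// <param name="source">The <see cref="DataFrame"/> to fit the model to</param>
/// <returns>A <see cref="StringIndexerModel"/> fitted to the input data</returns>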
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add the param and returns description?

public StringIndexerModel Fit(DataFrame source)
{
    // "fit" hands back a JVM object reference, which is wrapped as a StringIndexerModel.
    var fittedModel = (JvmObjectReference)_jvmObject.Invoke("fit", source);
    return new StringIndexerModel(fittedModel);
}

/// <summary>
/// Gets the value of the handleInvalid parameter, which controls how the
/// <see cref="StringIndexer"/> deals with invalid data.
/// </summary>
/// <returns>The handleInvalid option currently set</returns>
public string GetHandleInvalid() =>
    // Call the getter: "handleInvalid" itself returns the Param definition object,
    // not its string value, so casting that result to string would fail.
    (string)_jvmObject.Invoke("getHandleInvalid");

/// <summary>
/// Sets the handleInvalid option on the <see cref="StringIndexer"/>.
/// </summary>
/// <param name="handleInvalid">The handleInvalid option to apply</param>
/// <returns>
/// A <see cref="StringIndexer"/> wrapping the object returned by the JVM setter.
/// </returns>
public StringIndexer SetHandleInvalid(string handleInvalid)
{
    object updated = _jvmObject.Invoke("setHandleInvalid", handleInvalid);
    return new StringIndexer((JvmObjectReference)updated);
}

/// <summary>
/// Gets the name of the input column.
/// </summary>
/// <returns>The input column name currently set</returns>
public string GetInputCol() =>
    // Call the getter: "inputCol" itself returns the Param definition object,
    // not the column name, so casting that result to string would fail.
    (string)_jvmObject.Invoke("getInputCol");

/// <summary>
/// Sets the input column on the <see cref="StringIndexer"/>.
/// </summary>
/// <param name="inputCol">Name of the input column</param>
/// <returns>
/// A <see cref="StringIndexer"/> wrapping the object returned by the JVM setter.
/// </returns>
public StringIndexer SetInputCol(string inputCol)
{
    object updated = _jvmObject.Invoke("setInputCol", inputCol);
    return new StringIndexer((JvmObjectReference)updated);
}

/// <summary>
/// Gets the names of the input columns.
/// </summary>
/// <returns>The input column names currently set</returns>
public string[] GetInputCols() =>
    // Call the getter: "inputCols" itself returns the Param definition object,
    // not the array of column names, so casting that result to string[] would fail.
    (string[])_jvmObject.Invoke("getInputCols");

/// <summary>
/// Sets the input columns on the <see cref="StringIndexer"/>.
/// </summary>
/// <remarks>
/// NOTE(review): Spark's API documentation lists multi-column support for StringIndexer
/// as added in Spark 3.0.0 - confirm the minimum Spark version this wrapper targets.
/// </remarks>
/// <param name="inputCols">Names of the input columns</param>
/// <returns>
/// A <see cref="StringIndexer"/> wrapping the object returned by the JVM setter.
/// </returns>
public StringIndexer SetInputCols(string[] inputCols) =>
    // Two fixes over the single-column call this previously made:
    //  - target the plural setter, "setInputCols", which accepts an array;
    //  - cast to object so the array is sent as ONE argument. A string[] passed
    //    directly to a "params object[]" parameter is spread into separate arguments.
    WrapAsStringIndexer(
        (JvmObjectReference)_jvmObject.Invoke("setInputCols", (object)inputCols));

/// <summary>
/// Gets the name of the output column.
/// </summary>
/// <returns>The output column name currently set</returns>
public string GetOutputCol() =>
    // Call the getter: "outputCol" itself returns the Param definition object,
    // not the column name, so casting that result to string would fail.
    (string)_jvmObject.Invoke("getOutputCol");

/// <summary>
/// Sets the output column on the <see cref="StringIndexer"/>.
/// </summary>
/// <param name="outputCol">Name of the output column</param>
/// <returns>
/// A <see cref="StringIndexer"/> wrapping the object returned by the JVM setter.
/// </returns>
public StringIndexer SetOutputCol(string outputCol)
{
    object updated = _jvmObject.Invoke("setOutputCol", outputCol);
    return new StringIndexer((JvmObjectReference)updated);
}

/// <summary>
/// Gets the names of the output columns.
/// </summary>
/// <returns>The output column names currently set</returns>
public string[] GetOutputCols() =>
    // Call the getter: "outputCols" itself returns the Param definition object,
    // not the array of column names, so casting that result to string[] would fail.
    (string[])_jvmObject.Invoke("getOutputCols");

/// <summary>
/// Sets the output columns on the <see cref="StringIndexer"/>.
/// </summary>
/// <remarks>
/// NOTE(review): Spark's API documentation lists multi-column support for StringIndexer
/// as added in Spark 3.0.0 - confirm the minimum Spark version this wrapper targets.
/// </remarks>
/// <param name="outputCols">Names of the output columns</param>
/// <returns>
/// A <see cref="StringIndexer"/> wrapping the object returned by the JVM setter.
/// </returns>
public StringIndexer SetOutputCols(string[] outputCols) =>
    // Two fixes over the single-column call this previously made:
    //  - target the plural setter, "setOutputCols", which accepts an array;
    //  - cast to object so the array is sent as ONE argument. A string[] passed
    //    directly to a "params object[]" parameter is spread into separate arguments.
    WrapAsStringIndexer(
        (JvmObjectReference)_jvmObject.Invoke("setOutputCols", (object)outputCols));

/// <summary>
/// Gets the value of the stringOrderType parameter, which controls how labels are
/// ordered when indices are assigned.
/// </summary>
/// <returns>The stringOrderType option currently set</returns>
public string GetStringOrderType() =>
    // Call the getter: "stringOrderType" itself returns the Param definition object,
    // not its string value, so casting that result to string would fail.
    (string)_jvmObject.Invoke("getStringOrderType");

/// <summary>
/// Sets the stringOrderType option on the <see cref="StringIndexer"/>.
/// </summary>
/// <param name="stringOrderType">The stringOrderType option to apply</param>
/// <returns>
/// A <see cref="StringIndexer"/> wrapping the object returned by the JVM setter.
/// </returns>
public StringIndexer SetStringOrderType(string stringOrderType)
{
    object updated = _jvmObject.Invoke("setStringOrderType", stringOrderType);
    return new StringIndexer((JvmObjectReference)updated);
}

/// <summary>
/// Loads a <see cref="StringIndexer"/> that was previously written with Save.
/// </summary>
/// <param name="path">The path the <see cref="StringIndexer"/> was saved to</param>
/// <returns>A new <see cref="StringIndexer"/> object, loaded from path</returns>
public static StringIndexer Load(string path)
{
    // "load" is invoked as a static method on the class named by s_StringIndexerClassName.
    object loaded = SparkEnvironment.JvmBridge.CallStaticJavaMethod(
        s_StringIndexerClassName,
        "load",
        path);

    return WrapAsStringIndexer(loaded);
}

// Casts the value returned by a JVM call to a JvmObjectReference and wraps it in a
// new StringIndexer. Throws InvalidCastException if obj is not a JvmObjectReference.
private static StringIndexer WrapAsStringIndexer(object obj)
{
    var reference = (JvmObjectReference)obj;
    return new StringIndexer(reference);
}
}
}
Loading