Skip to content

Add NGram #734

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 5, 2021
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.IO;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
/// <summary>
/// Test suite for <see cref="NGram"/> class.
/// </summary>
[Collection("Spark E2E Tests")]
public class NGramTests : FeatureBaseTests<NGram>
{
private readonly SparkSession _spark;

public NGramTests(SparkFixture fixture) : base(fixture)
{
_spark = fixture.Spark;
}

/// <summary>
/// Test case to test the methods in <see cref="NGram"/> class.
/// </summary>
[Fact]
public void TestNGram()
{
string expectedUid = "theUid";
string expectedInputCol = "input_col";
string expectedOutputCol = "output_col";
int expectedN = 2;

DataFrame input = _spark.Sql("SELECT split('Hi I heard about Spark', ' ') as input_col");

NGram nGram = new NGram(expectedUid)
.SetInputCol(expectedInputCol)
.SetOutputCol(expectedOutputCol)
.SetN(expectedN);

StructType outputSchema = nGram.TransformSchema(input.Schema());

DataFrame output = nGram.Transform(input);

Assert.Contains(output.Schema().Fields, (f => f.Name == expectedOutputCol));
Assert.Contains(outputSchema.Fields, (f => f.Name == expectedOutputCol));
Assert.Equal(expectedInputCol, nGram.GetInputCol());
Assert.Equal(expectedOutputCol, nGram.GetOutputCol());
Assert.Equal(expectedN, nGram.GetN());

using (var tempDirectory = new TemporaryDirectory())
{
string savePath = Path.Join(tempDirectory.Path, "NGram");
nGram.Save(savePath);

NGram loadedNGram = NGram.Load(savePath);
Assert.Equal(nGram.Uid(), loadedNGram.Uid());
}

Assert.Equal(expectedUid, nGram.Uid());

TestFeatureBase(nGram, "inputCol", "input_col");
}
}
}
131 changes: 131 additions & 0 deletions src/csharp/Microsoft.Spark/ML/Feature/NGram.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;

namespace Microsoft.Spark.ML.Feature
{
/// <summary>
/// Class <see cref="NGram"/> transformer that converts the input array of strings into
/// an array of n-grams. Null values in the input array are ignored. It returns an array
/// of n-grams where each n-gram is represented by a space-separated string of words.
/// </summary>
public class NGram : FeatureBase<NGram>, IJvmObjectReferenceProvider
{
private static readonly string s_nGramClassName =
"org.apache.spark.ml.feature.NGram";

/// <summary>
/// Create a <see cref="NGram"/> without any parameters.
/// </summary>
public NGram() : base(s_nGramClassName)
{
}

/// <summary>
/// Create a <see cref="NGram"/> with a UID that is used to give the
/// <see cref="NGram"/> a unique ID.
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.
/// </param>
public NGram(string uid) : base(s_nGramClassName, uid)
{
}

internal NGram(JvmObjectReference jvmObject) : base(jvmObject)
{
}

JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;

/// <summary>
/// Gets the column that the <see cref="NGram"/> should read from.
/// </summary>
/// <returns>string, input column</returns>
public string GetInputCol() => (string)_jvmObject.Invoke("getInputCol");

/// <summary>
/// Sets the column that the <see cref="NGram"/> should read from.
/// </summary>
/// <param name="value">The name of the column to as the source</param>
/// <returns>New <see cref="NGram"/> object</returns>
public NGram SetInputCol(string value) => WrapAsNGram(_jvmObject.Invoke("setInputCol", value));

/// <summary>
/// Gets the output column that the <see cref="NGram"/> writes.
/// </summary>
/// <returns>string, the output column</returns>
public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol");

/// <summary>
/// Sets the output column that the <see cref="NGram"/> writes.
/// </summary>
/// <param name="value">The name of the new column</param>
/// <returns>New <see cref="NGram"/> object</returns>
public NGram SetOutputCol(string value) => WrapAsNGram(_jvmObject.Invoke("setOutputCol", value));

/// <summary>
/// Gets N value for <see cref="NGram"/>.
/// </summary>
/// <returns>int, N value</returns>
public int GetN() => (int)_jvmObject.Invoke("getN");

/// <summary>
/// Sets N value for <see cref="NGram"/>.
/// </summary>
/// <param name="value">N value</param>
/// <returns>New <see cref="NGram"/> object</returns>
public NGram SetN(int value) => WrapAsNGram(_jvmObject.Invoke("setN", value));

/// <summary>
/// Executes the <see cref="NGram"/> and transforms the DataFrame to include the new
/// column.
/// </summary>
/// <param name="source">The DataFrame to transform</param>
/// <returns>
/// New <see cref="DataFrame"/> object with the source <see cref="DataFrame"/> transformed.
/// </returns>
public DataFrame Transform(DataFrame source) =>
new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", source));

/// <summary>
/// Check transform validity and derive the output schema from the input schema.
///
/// This checks for validity of interactions between parameters during Transform and
/// raises an exception if any parameter value is invalid.
///
/// Typical implementation should first conduct verification on schema change and parameter
/// validity, including complex parameter interaction checks.
/// </summary>
/// <param name="value">
/// The <see cref="StructType"/> of the <see cref="DataFrame"/> which will be transformed.
/// </param>
/// <returns>
/// The <see cref="StructType"/> of the output schema that would have been derived from the
/// input schema, if Transform had been called.
/// </returns>
public StructType TransformSchema(StructType value) =>
new StructType(
(JvmObjectReference)_jvmObject.Invoke(
"transformSchema",
DataType.FromJson(_jvmObject.Jvm, value.Json)));

/// <summary>
/// Loads the <see cref="NGram"/> that was previously saved using Save.
/// </summary>
/// <param name="path">The path the previous <see cref="NGram"/> was saved to</param>
/// <returns>New <see cref="NGram"/> object, loaded from path</returns>
public static NGram Load(string path) =>
WrapAsNGram(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_nGramClassName,
"load",
path));

private static NGram WrapAsNGram(object obj) => new NGram((JvmObjectReference)obj);
}
}