diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs index e9193fd0b..949ba06da 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/BucketizerTests.cs @@ -13,19 +13,24 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature { [Collection("Spark E2E Tests")] - public class BucketizerTests + public class BucketizerTests : FeatureBaseTests { private readonly SparkSession _spark; - public BucketizerTests(SparkFixture fixture) + public BucketizerTests(SparkFixture fixture) : base(fixture) { _spark = fixture.Spark; } + /// + /// Create a , create a and test the + /// available methods. Test the FeatureBase methods using . + /// [Fact] public void TestBucketizer() { - var expectedSplits = new double[] { double.MinValue, 0.0, 10.0, 50.0, double.MaxValue }; + var expectedSplits = + new double[] { double.MinValue, 0.0, 10.0, 50.0, double.MaxValue }; string expectedHandle = "skip"; string expectedUid = "uid"; @@ -60,18 +65,7 @@ public void TestBucketizer() Assert.Equal(bucketizer.Uid(), loadedBucketizer.Uid()); } - Assert.NotEmpty(bucketizer.ExplainParams()); - - Param handleInvalidParam = bucketizer.GetParam("handleInvalid"); - Assert.NotEmpty(handleInvalidParam.Doc); - Assert.NotEmpty(handleInvalidParam.Name); - Assert.Equal(handleInvalidParam.Parent, bucketizer.Uid()); - - Assert.NotEmpty(bucketizer.ExplainParam(handleInvalidParam)); - bucketizer.Set(handleInvalidParam, "keep"); - Assert.Equal("keep", bucketizer.GetHandleInvalid()); - - Assert.Equal("error", bucketizer.Clear(handleInvalidParam).GetHandleInvalid()); + TestFeatureBase(bucketizer, "handleInvalid", "keep"); } [Fact] diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerModelTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerModelTests.cs index 97458d173..e8ea1ade4 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerModelTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerModelTests.cs @@ -12,17 +12,17 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature { [Collection("Spark E2E Tests")] - public class CountVectorizerModelTests + public class CountVectorizerModelTests : FeatureBaseTests { private readonly SparkSession _spark; - public CountVectorizerModelTests(SparkFixture fixture) + public CountVectorizerModelTests(SparkFixture fixture) : base(fixture) { _spark = fixture.Spark; } /// - /// Test that we can create a CountVectorizerModel, pass in a specifc vocabulary to use + /// Test that we can create a CountVectorizerModel, pass in a specific vocabulary to use /// when creating the model. Verify the standard features methods as well as load/save. /// [Fact] @@ -68,6 +68,8 @@ public void TestCountVectorizerModel() Assert.IsType(countVectorizerModel.GetVocabSize()); Assert.NotEmpty(countVectorizerModel.ExplainParams()); Assert.NotEmpty(countVectorizerModel.ToString()); + + TestFeatureBase(countVectorizerModel, "minDF", 100); } } } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerTests.cs index 9e022ba69..5d046dc87 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/CountVectorizerTests.cs @@ -13,11 +13,11 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature { [Collection("Spark E2E Tests")] - public class CountVectorizerTests + public class CountVectorizerTests : FeatureBaseTests { private readonly SparkSession _spark; - public CountVectorizerTests(SparkFixture fixture) + public CountVectorizerTests(SparkFixture fixture) : base(fixture) { _spark = fixture.Spark; } @@ -67,6 +67,8 @@ public void TestCountVectorizer() Assert.NotEmpty(countVectorizer.ExplainParams()); Assert.NotEmpty(countVectorizer.ToString()); + + TestFeatureBase(countVectorizer, "minDF", 0.4); } /// diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureBaseTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureBaseTests.cs new file mode 100644 index 000000000..01903e510 --- /dev/null +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureBaseTests.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Spark.ML.Feature; +using Microsoft.Spark.ML.Feature.Param; +using Microsoft.Spark.Sql; +using Xunit; + +namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature +{ + public class FeatureBaseTests + { + private readonly SparkSession _spark; + + protected FeatureBaseTests(SparkFixture fixture) + { + _spark = fixture.Spark; + } + + /// + /// Tests the common functionality across all ML.Feature classes. + /// + /// The object that implemented FeatureBase + /// The name of a parameter that can be set on this object + /// A parameter value that can be set on this object + public void TestFeatureBase( + FeatureBase testObject, + string paramName, + object paramValue) + { + Assert.NotEmpty(testObject.ExplainParams()); + + Param param = testObject.GetParam(paramName); + Assert.NotEmpty(param.Doc); + Assert.NotEmpty(param.Name); + Assert.Equal(param.Parent, testObject.Uid()); + + Assert.NotEmpty(testObject.ExplainParam(param)); + testObject.Set(param, paramValue); + Assert.IsAssignableFrom(testObject.Clear(param)); + + Assert.IsType(testObject.Uid()); + } + } +} diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureHasherTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureHasherTests.cs new file mode 100644 index 000000000..fe169a9f0 --- /dev/null +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureHasherTests.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.Spark.ML.Feature; +using Microsoft.Spark.Sql; +using Microsoft.Spark.Sql.Types; +using Xunit; + +namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature +{ + [Collection("Spark E2E Tests")] + public class FeatureHasherTests : FeatureBaseTests + { + private readonly SparkSession _spark; + + public FeatureHasherTests(SparkFixture fixture) : base(fixture) + { + _spark = fixture.Spark; + } + + /// + /// Create a , create a and test the + /// available methods. Test the FeatureBase methods using . + /// + [Fact] + public void TestFeatureHasher() + { + DataFrame dataFrame = _spark.CreateDataFrame( + new List + { + new GenericRow(new object[] { 2.0D, true, "1", "foo" }), + new GenericRow(new object[] { 3.0D, false, "2", "bar" }) + }, + new StructType(new List + { + new StructField("real", new DoubleType()), + new StructField("bool", new BooleanType()), + new StructField("stringNum", new StringType()), + new StructField("string", new StringType()) + })); + + FeatureHasher hasher = new FeatureHasher() + .SetInputCols(new List() { "real", "bool", "stringNum", "string" }) + .SetOutputCol("features") + .SetCategoricalCols(new List() { "real", "string" }) + .SetNumFeatures(10); + + Assert.IsType(hasher.GetOutputCol()); + Assert.IsType(hasher.GetInputCols()); + Assert.IsType(hasher.GetCategoricalCols()); + Assert.IsType(hasher.GetNumFeatures()); + Assert.IsType(hasher.TransformSchema(dataFrame.Schema())); + Assert.IsType(hasher.Transform(dataFrame)); + + TestFeatureBase(hasher, "numFeatures", 1000); + } + } +} diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/HashingTFTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/HashingTFTests.cs index df459ed7a..246b4516e 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/HashingTFTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/HashingTFTests.cs @@ -11,11 +11,11 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature { [Collection("Spark E2E Tests")] - public class HashingTFTests + public class HashingTFTests : FeatureBaseTests { private readonly SparkSession _spark; - public HashingTFTests(SparkFixture fixture) + public HashingTFTests(SparkFixture fixture) : base(fixture) { _spark = fixture.Spark; } @@ -57,6 +57,8 @@ public void TestHashingTF() hashingTf.SetBinary(true); Assert.True(hashingTf.GetBinary()); + + TestFeatureBase(hashingTf, "numFeatures", 1000); } } } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFModelTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFModelTests.cs index 202187809..1894373a6 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFModelTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFModelTests.cs @@ -11,11 +11,11 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature { [Collection("Spark E2E Tests")] - public class IDFModelTests + public class IDFModelTests : FeatureBaseTests { private readonly SparkSession _spark; - public IDFModelTests(SparkFixture fixture) + public IDFModelTests(SparkFixture fixture) : base(fixture) { _spark = fixture.Spark; } @@ -65,6 +65,8 @@ public void TestIDFModel() IDFModel loadedModel = IDFModel.Load(modelPath); Assert.Equal(idfModel.Uid(), loadedModel.Uid()); } + + TestFeatureBase(idfModel, "minDocFreq", 1000); } } } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFTests.cs index 72da97887..64698ac9a 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/IDFTests.cs @@ -11,11 +11,11 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature { [Collection("Spark E2E Tests")] - public class IDFTests + public class IDFTests : FeatureBaseTests { private readonly SparkSession _spark; - public IDFTests(SparkFixture fixture) + public IDFTests(SparkFixture fixture) : base(fixture) { _spark = fixture.Spark; } @@ -44,6 +44,8 @@ public void TestIDFModel() IDF loadedIdf = IDF.Load(savePath); Assert.Equal(idf.Uid(), loadedIdf.Uid()); } + + TestFeatureBase(idf, "minDocFreq", 1000); } } } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/TokenizerTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/TokenizerTests.cs index 4b1998f50..af76ac523 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/TokenizerTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/TokenizerTests.cs @@ -11,11 +11,11 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature { [Collection("Spark E2E Tests")] - public class TokenizerTests + public class TokenizerTests : FeatureBaseTests { private readonly SparkSession _spark; - public TokenizerTests(SparkFixture fixture) + public TokenizerTests(SparkFixture fixture) : base(fixture) { _spark = fixture.Spark; } @@ -50,6 +50,8 @@ public void TestTokenizer() } Assert.Equal(expectedUid, tokenizer.Uid()); + + TestFeatureBase(tokenizer, "inputCol", "input_col"); } } } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecModelTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecModelTests.cs index a5227149b..04c7d7a79 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecModelTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecModelTests.cs @@ -11,11 +11,11 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature { [Collection("Spark E2E Tests")] - public class Word2VecModelTests + public class Word2VecModelTests : FeatureBaseTests { private readonly SparkSession _spark; - public Word2VecModelTests(SparkFixture fixture) + public Word2VecModelTests(SparkFixture fixture) : base(fixture) { _spark = fixture.Spark; } @@ -47,6 +47,8 @@ public void TestWord2VecModel() Word2VecModel loadedModel = Word2VecModel.Load(savePath); Assert.Equal(model.Uid(), loadedModel.Uid()); } + + TestFeatureBase(model, "maxIter", 2); } } } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecTests.cs index 1d5da5335..1c36eb2c2 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/Word2VecTests.cs @@ -11,11 +11,11 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature { [Collection("Spark E2E Tests")] - public class Word2VecTests + public class Word2VecTests : FeatureBaseTests { private readonly SparkSession _spark; - public Word2VecTests(SparkFixture fixture) + public Word2VecTests(SparkFixture fixture) : base(fixture) { _spark = fixture.Spark; } @@ -67,6 +67,8 @@ public void TestWord2Vec() Word2Vec loadedWord2Vec = Word2Vec.Load(savePath); Assert.Equal(word2vec.Uid(), loadedWord2Vec.Uid()); } + + TestFeatureBase(word2vec, "maxIter", 2); } } } diff --git a/src/csharp/Microsoft.Spark/ML/Feature/FeatureHasher.cs b/src/csharp/Microsoft.Spark/ML/Feature/FeatureHasher.cs new file mode 100644 index 000000000..fb89b1051 --- /dev/null +++ b/src/csharp/Microsoft.Spark/ML/Feature/FeatureHasher.cs @@ -0,0 +1,147 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.Spark.Interop; +using Microsoft.Spark.Interop.Ipc; +using Microsoft.Spark.Sql; +using Microsoft.Spark.Sql.Types; + +namespace Microsoft.Spark.ML.Feature +{ + public class FeatureHasher: FeatureBase, IJvmObjectReferenceProvider + { + private static readonly string s_featureHasherClassName = + "org.apache.spark.ml.feature.FeatureHasher"; + + internal FeatureHasher() : base(s_featureHasherClassName) + { + } + + internal FeatureHasher(string uid) : base(s_featureHasherClassName, uid) + { + } + + internal FeatureHasher(JvmObjectReference jvmObject) : base(jvmObject) + { + } + + JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject; + + /// + /// Loads the that was previously saved using Save. + /// + /// + /// The path the previous was saved to. + /// + /// New object + public static FeatureHasher Load(string path) => + WrapAsFeatureHasher( + SparkEnvironment.JvmBridge.CallStaticJavaMethod( + s_featureHasherClassName, + "load", + path)); + + /// + /// Gets a list of the columns which have been specified as categorical columns. + /// + /// List of categorical columns, set by SetCategoricalCols + public IEnumerable GetCategoricalCols() => + (string[])_jvmObject.Invoke("getCategoricalCols"); + + /// + /// Marks columns as categorical columns. + /// + /// List of column names to mark as categorical columns + /// New object + public FeatureHasher SetCategoricalCols(IEnumerable value) => + WrapAsFeatureHasher(_jvmObject.Invoke("setCategoricalCols", value)); + + /// + /// Gets the columns that the should read from and convert into + /// hashes. This would have been set by SetInputCol. + /// + /// List of the input columns, set by SetInputCols + public IEnumerable GetInputCols() => (string[])_jvmObject.Invoke("getInputCols"); + + /// + /// Sets the columns that the should read from and convert into + /// hashes. + /// + /// The name of the column to as use the source of the hash + /// New object + public FeatureHasher SetInputCols(IEnumerable value) => + WrapAsFeatureHasher(_jvmObject.Invoke("setInputCols", value)); + + /// + /// Gets the number of features that should be used. Since a simple modulo is used to + /// transform the hash function to a column index, it is advisable to use a power of two + /// as the numFeatures parameter; otherwise the features will not be mapped evenly to the + /// columns. + /// + /// The number of features to be used + public int GetNumFeatures() => (int)_jvmObject.Invoke("getNumFeatures"); + + /// + /// Sets the number of features that should be used. Since a simple modulo is used to + /// transform the hash function to a column index, it is advisable to use a power of two as + /// the numFeatures parameter; otherwise the features will not be mapped evenly to the + /// columns. + /// + /// int value of number of features + /// New object + public FeatureHasher SetNumFeatures(int value) => + WrapAsFeatureHasher(_jvmObject.Invoke("setNumFeatures", value)); + + /// + /// Gets the name of the column the output data will be written to. This is set by + /// SetInputCol. + /// + /// string, the output column + public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol"); + + /// + /// Sets the name of the new column in the created by Transform. + /// + /// The name of the new column which will contain the hash + /// New object + public FeatureHasher SetOutputCol(string value) => + WrapAsFeatureHasher(_jvmObject.Invoke("setOutputCol", value)); + + /// + /// Transforms the input . It is recommended that you validate that + /// the transform will succeed by calling TransformSchema. + /// + /// Input to transform + /// Transformed + public DataFrame Transform(DataFrame value) => + new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", value)); + + /// + /// Check transform validity and derive the output schema from the input schema. + /// + /// This checks for validity of interactions between parameters during Transform and + /// raises an exception if any parameter value is invalid. + /// + /// Typical implementation should first conduct verification on schema change and parameter + /// validity, including complex parameter interaction checks. + /// + /// + /// The of the which will be transformed. + /// + /// + /// The of the output schema that would have been derived from the + /// input schema, if Transform had been called. + /// + public StructType TransformSchema(StructType value) => + new StructType( + (JvmObjectReference)_jvmObject.Invoke( + "transformSchema", + DataType.FromJson(_jvmObject.Jvm, value.Json))); + + private static FeatureHasher WrapAsFeatureHasher(object obj) => + new FeatureHasher((JvmObjectReference)obj); + } +}