Skip to content

Commit 9239e72

Browse files
authored
Implement ML/CountVectorizer and ML/CountVectorizerModel (#608)
1 parent 6064831 commit 9239e72

File tree

4 files changed

+522
-0
lines changed

4 files changed

+522
-0
lines changed
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,73 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using System.IO;
7+
using Microsoft.Spark.ML.Feature;
8+
using Microsoft.Spark.Sql;
9+
using Microsoft.Spark.UnitTest.TestUtils;
10+
using Xunit;
11+
12+
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
13+
{
14+
[Collection("Spark E2E Tests")]
public class CountVectorizerModelTests
{
    private readonly SparkSession _spark;

    public CountVectorizerModelTests(SparkFixture fixture) => _spark = fixture.Spark;

    /// <summary>
    /// Builds a CountVectorizerModel from a specific, caller-supplied vocabulary, then
    /// checks the parameter setters/getters, a save/load round trip, and the descriptive
    /// methods (vocabulary size, ExplainParams, ToString).
    /// </summary>
    [Fact]
    public void TestCountVectorizerModel()
    {
        const string expectedInputCol = "input";
        const string expectedOutputCol = "output";
        const double expectedMinTf = 10.0;
        const bool expectedBinary = false;

        List<string> terms = new List<string> { "hello", "I", "AM", "TO", "TOKENIZE" };

        CountVectorizerModel model = new CountVectorizerModel(terms);

        // The overload that takes an explicit uid must also yield a CountVectorizerModel.
        CountVectorizerModel modelWithUid = new CountVectorizerModel("my-uid", terms);
        Assert.IsType<CountVectorizerModel>(modelWithUid);

        // Keep the instance handed back by the last setter and run the getters against it.
        model = model
            .SetInputCol(expectedInputCol)
            .SetOutputCol(expectedOutputCol)
            .SetMinTF(expectedMinTf)
            .SetBinary(expectedBinary);

        Assert.Equal(expectedInputCol, model.GetInputCol());
        Assert.Equal(expectedOutputCol, model.GetOutputCol());
        Assert.Equal(expectedMinTf, model.GetMinTF());
        Assert.Equal(expectedBinary, model.GetBinary());

        // Save/load round trip: the reloaded model must report the same uid.
        using (TemporaryDirectory scratch = new TemporaryDirectory())
        {
            string modelPath = Path.Join(scratch.Path, "countVectorizerModel");
            model.Save(modelPath);

            CountVectorizerModel reloaded = CountVectorizerModel.Load(modelPath);
            Assert.Equal(model.Uid(), reloaded.Uid());
        }

        Assert.IsType<int>(model.GetVocabSize());
        Assert.NotEmpty(model.ExplainParams());
        Assert.NotEmpty(model.ToString());
    }
}
73+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,83 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.IO;
7+
using Microsoft.Spark.E2ETest.Utils;
8+
using Microsoft.Spark.ML.Feature;
9+
using Microsoft.Spark.Sql;
10+
using Microsoft.Spark.UnitTest.TestUtils;
11+
using Xunit;
12+
13+
namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
14+
{
15+
[Collection("Spark E2E Tests")]
public class CountVectorizerTests
{
    private readonly SparkSession _spark;

    public CountVectorizerTests(SparkFixture fixture)
    {
        _spark = fixture.Spark;
    }

    /// <summary>
    /// Test that we can create a CountVectorizer and fit it to a DataFrame. Verify every
    /// parameter setter/getter pair (including the binary toggle), the save/load round
    /// trip, and the descriptive methods.
    /// </summary>
    [Fact]
    public void TestCountVectorizer()
    {
        DataFrame input = _spark.Sql("SELECT array('hello', 'I', 'AM', 'a', 'string', 'TO', " +
            "'TOKENIZE') as input from range(100)");

        const string inputColumn = "input";
        const string outputColumn = "output";
        const double minDf = 1;
        const double minTf = 10;
        const int vocabSize = 10000;
        const bool binary = false;

        var countVectorizer = new CountVectorizer();

        // The objects returned by the setters are intentionally discarded: the getters
        // below are read from the original instance.
        // SetBinary is part of the chain so that the setter itself is exercised; without
        // it, the GetBinary assertion would only ever observe the default value.
        countVectorizer
            .SetInputCol(inputColumn)
            .SetOutputCol(outputColumn)
            .SetMinDF(minDf)
            .SetMinTF(minTf)
            .SetVocabSize(vocabSize)
            .SetBinary(binary);

        Assert.IsType<CountVectorizerModel>(countVectorizer.Fit(input));
        Assert.Equal(inputColumn, countVectorizer.GetInputCol());
        Assert.Equal(outputColumn, countVectorizer.GetOutputCol());
        Assert.Equal(minDf, countVectorizer.GetMinDF());
        Assert.Equal(minTf, countVectorizer.GetMinTF());
        Assert.Equal(vocabSize, countVectorizer.GetVocabSize());
        Assert.Equal(binary, countVectorizer.GetBinary());

        // Save/load round trip: the reloaded vectorizer must report the same uid.
        using (var tempDirectory = new TemporaryDirectory())
        {
            string savePath = Path.Join(tempDirectory.Path, "countVectorizer");
            countVectorizer.Save(savePath);

            CountVectorizer loadedVectorizer = CountVectorizer.Load(savePath);
            Assert.Equal(countVectorizer.Uid(), loadedVectorizer.Uid());
        }

        Assert.NotEmpty(countVectorizer.ExplainParams());
        Assert.NotEmpty(countVectorizer.ToString());
    }

    /// <summary>
    /// Test signatures for APIs introduced in Spark 2.4.*.
    /// </summary>
    [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
    public void TestSignaturesV2_4_X()
    {
        const double maxDf = 100;
        CountVectorizer countVectorizer = new CountVectorizer().SetMaxDF(maxDf);
        Assert.Equal(maxDf, countVectorizer.GetMaxDF());
    }
}
83+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,198 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Spark.Interop;
6+
using Microsoft.Spark.Interop.Ipc;
7+
using Microsoft.Spark.Sql;
8+
9+
namespace Microsoft.Spark.ML.Feature
10+
{
11+
/// <summary>
/// .NET wrapper around the JVM class <c>org.apache.spark.ml.feature.CountVectorizer</c>.
/// Calling <see cref="Fit"/> on a <see cref="DataFrame"/> produces a
/// <see cref="CountVectorizerModel"/>.
/// </summary>
public class CountVectorizer : FeatureBase<CountVectorizer>, IJvmObjectReferenceProvider
{
    private static readonly string s_countVectorizerClassName =
        "org.apache.spark.ml.feature.CountVectorizer";

    /// <summary>
    /// Creates a <see cref="CountVectorizer"/> without any parameters.
    /// </summary>
    public CountVectorizer() : base(s_countVectorizerClassName)
    {
    }

    /// <summary>
    /// Creates a <see cref="CountVectorizer"/> with a UID that is used to give the
    /// <see cref="CountVectorizer"/> a unique ID.
    /// </summary>
    /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
    public CountVectorizer(string uid) : base(s_countVectorizerClassName, uid)
    {
    }

    /// <summary>
    /// Wraps an existing JVM object reference in a <see cref="CountVectorizer"/>.
    /// </summary>
    /// <param name="jvmObject">Reference to the JVM-side object to wrap.</param>
    internal CountVectorizer(JvmObjectReference jvmObject) : base(jvmObject)
    {
    }

    JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;

    /// <summary>Fits a model to the input data.</summary>
    /// <param name="dataFrame">The <see cref="DataFrame"/> to fit the model to.</param>
    /// <returns><see cref="CountVectorizerModel"/></returns>
    public CountVectorizerModel Fit(DataFrame dataFrame) =>
        new CountVectorizerModel((JvmObjectReference)_jvmObject.Invoke("fit", dataFrame));

    /// <summary>
    /// Loads the <see cref="CountVectorizer"/> that was previously saved using Save.
    /// </summary>
    /// <param name="path">
    /// The path the previous <see cref="CountVectorizer"/> was saved to.
    /// </param>
    /// <returns>New <see cref="CountVectorizer"/> object</returns>
    public static CountVectorizer Load(string path) =>
        WrapAsCountVectorizer(
            SparkEnvironment.JvmBridge.CallStaticJavaMethod(
                s_countVectorizerClassName, "load", path));

    /// <summary>
    /// Gets the binary toggle to control the output vector values. If True, all nonzero counts
    /// (after minTF filter applied) are set to 1. This is useful for discrete probabilistic
    /// models that model binary events rather than integer counts. Default: false
    /// </summary>
    /// <returns>boolean</returns>
    public bool GetBinary() => (bool)_jvmObject.Invoke("getBinary");

    /// <summary>
    /// Sets the binary toggle to control the output vector values. If True, all nonzero counts
    /// (after minTF filter applied) are set to 1. This is useful for discrete probabilistic
    /// models that model binary events rather than integer counts. Default: false
    /// </summary>
    /// <param name="value">Turn the binary toggle on or off</param>
    /// <returns>New <see cref="CountVectorizer"/> with the binary toggle value set</returns>
    public CountVectorizer SetBinary(bool value) =>
        WrapAsCountVectorizer(_jvmObject.Invoke("setBinary", value));

    /// <summary>
    /// Gets the column that the <see cref="CountVectorizer"/> should read from. This would
    /// have been set by SetInputCol.
    /// </summary>
    /// <returns>The input column of type string</returns>
    public string GetInputCol() => (string)_jvmObject.Invoke("getInputCol");

    /// <summary>
    /// Sets the column that the <see cref="CountVectorizer"/> should read from.
    /// </summary>
    /// <param name="value">The name of the column to use as the source.</param>
    /// <returns>New <see cref="CountVectorizer"/> with the input column set</returns>
    public CountVectorizer SetInputCol(string value) =>
        WrapAsCountVectorizer(_jvmObject.Invoke("setInputCol", value));

    /// <summary>
    /// Gets the name of the new column the <see cref="CountVectorizer"/> creates in the
    /// DataFrame.
    /// </summary>
    /// <returns>The name of the output column.</returns>
    public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol");

    /// <summary>
    /// Sets the name of the new column the <see cref="CountVectorizer"/> creates in the
    /// DataFrame.
    /// </summary>
    /// <param name="value">The name of the output column which will be created.</param>
    /// <returns>New <see cref="CountVectorizer"/> with the output column set</returns>
    public CountVectorizer SetOutputCol(string value) =>
        WrapAsCountVectorizer(_jvmObject.Invoke("setOutputCol", value));

    /// <summary>
    /// Gets the maximum number of different documents a term could appear in to be included in
    /// the vocabulary. A term that appears more than the threshold will be ignored. If this is
    /// an integer greater than or equal to 1, this specifies the maximum number of documents
    /// the term could appear in; if this is a double in [0,1), then this specifies the maximum
    /// fraction of documents the term could appear in.
    /// </summary>
    /// <returns>The maximum document term frequency</returns>
    [Since(Versions.V2_4_0)]
    public double GetMaxDF() => (double)_jvmObject.Invoke("getMaxDF");

    /// <summary>
    /// Sets the maximum number of different documents a term could appear in to be included in
    /// the vocabulary. A term that appears more than the threshold will be ignored. If this is
    /// an integer greater than or equal to 1, this specifies the maximum number of documents
    /// the term could appear in; if this is a double in [0,1), then this specifies the maximum
    /// fraction of documents the term could appear in.
    /// </summary>
    /// <param name="value">The maximum document term frequency</param>
    /// <returns>New <see cref="CountVectorizer"/> with the max df value set</returns>
    [Since(Versions.V2_4_0)]
    public CountVectorizer SetMaxDF(double value) =>
        WrapAsCountVectorizer(_jvmObject.Invoke("setMaxDF", value));

    /// <summary>
    /// Gets the minimum number of different documents a term must appear in to be included in
    /// the vocabulary. If this is an integer greater than or equal to 1, this specifies the
    /// number of documents the term must appear in; if this is a double in [0,1), then this
    /// specifies the fraction of documents.
    /// </summary>
    /// <returns>The minimum document term frequency</returns>
    public double GetMinDF() => (double)_jvmObject.Invoke("getMinDF");

    /// <summary>
    /// Sets the minimum number of different documents a term must appear in to be included in
    /// the vocabulary. If this is an integer greater than or equal to 1, this specifies the
    /// number of documents the term must appear in; if this is a double in [0,1), then this
    /// specifies the fraction of documents.
    /// </summary>
    /// <param name="value">The minimum document term frequency</param>
    /// <returns>New <see cref="CountVectorizer"/> with the min df value set</returns>
    public CountVectorizer SetMinDF(double value) =>
        WrapAsCountVectorizer(_jvmObject.Invoke("setMinDF", value));

    /// <summary>
    /// Gets the filter to ignore rare words in a document. For each document, terms with
    /// frequency/count less than the given threshold are ignored. If this is an integer
    /// greater than or equal to 1, then this specifies a count (of times the term must appear
    /// in the document); if this is a double in [0,1), then this specifies a fraction (out of
    /// the document's token count).
    ///
    /// Note that the parameter is only used in transform of CountVectorizerModel and does not
    /// affect fitting.
    /// </summary>
    /// <returns>Minimum term frequency</returns>
    public double GetMinTF() => (double)_jvmObject.Invoke("getMinTF");

    /// <summary>
    /// Sets the filter to ignore rare words in a document. For each document, terms with
    /// frequency/count less than the given threshold are ignored. If this is an integer
    /// greater than or equal to 1, then this specifies a count (of times the term must appear
    /// in the document); if this is a double in [0,1), then this specifies a fraction (out of
    /// the document's token count).
    ///
    /// Note that the parameter is only used in transform of CountVectorizerModel and does not
    /// affect fitting.
    /// </summary>
    /// <param name="value">Minimum term frequency</param>
    /// <returns>New <see cref="CountVectorizer"/> with the min term frequency set</returns>
    public CountVectorizer SetMinTF(double value) =>
        WrapAsCountVectorizer(_jvmObject.Invoke("setMinTF", value));

    /// <summary>
    /// Gets the max size of the vocabulary. <see cref="CountVectorizer"/> will build a
    /// vocabulary that only considers the top vocabSize terms ordered by term frequency across
    /// the corpus.
    /// </summary>
    /// <returns>The max size of the vocabulary of type int.</returns>
    public int GetVocabSize() => (int)_jvmObject.Invoke("getVocabSize");

    /// <summary>
    /// Sets the max size of the vocabulary. <see cref="CountVectorizer"/> will build a
    /// vocabulary that only considers the top vocabSize terms ordered by term frequency across
    /// the corpus.
    /// </summary>
    /// <param name="value">The max vocabulary size</param>
    /// <returns>New <see cref="CountVectorizer"/> with the max vocab value set</returns>
    public CountVectorizer SetVocabSize(int value) =>
        WrapAsCountVectorizer(_jvmObject.Invoke("setVocabSize", value));

    /// <summary>
    /// Wraps the result of a JVM call in a new <see cref="CountVectorizer"/>. The cast to
    /// <see cref="JvmObjectReference"/> is done here so that call sites can pass the raw
    /// result straight through.
    /// </summary>
    /// <param name="obj">The object returned by the JVM call.</param>
    /// <returns>New <see cref="CountVectorizer"/> wrapping the given reference</returns>
    private static CountVectorizer WrapAsCountVectorizer(object obj) =>
        new CountVectorizer((JvmObjectReference)obj);
}
198+
}

0 commit comments

Comments
 (0)