Skip to content

Commit a901048

Browse files
authored
Fix for trainer estimator metadata propagation (#909)
1 parent bebf523 commit a901048

File tree

10 files changed

+209
-95
lines changed

10 files changed

+209
-95
lines changed

src/Microsoft.ML.Core/Data/MetadataUtils.cs

+15
Original file line numberDiff line numberDiff line change
@@ -466,5 +466,20 @@ public static bool TryGetCategoricalFeatureIndices(ISchema schema, int colIndex,
466466

467467
return isValid;
468468
}
469+
470+
/// <summary>
471+
/// Produces the sequence of metadata columns that are generated by trainer estimators.
472+
/// </summary>
473+
/// <param name="isNormalized">Whether to also append the 'IsNormalized' metadata column (typically for the probability column).</param>
474+
public static IEnumerable<SchemaShape.Column> GetTrainerOutputMetadata(bool isNormalized = false)
475+
{
476+
var cols = new List<SchemaShape.Column>();
477+
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true));
478+
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextType.Instance, false));
479+
cols.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextType.Instance, false));
480+
if (isNormalized)
481+
cols.Add(new SchemaShape.Column(Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false));
482+
return cols;
483+
}
469484
}
470485
}

src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.Collections.Generic;
56
using System.Linq;
67
using Microsoft.ML.Core.Data;
78
using Microsoft.ML.Runtime.Data;

src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs

+81-83
Large diffs are not rendered by default.

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs

+18-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
using Microsoft.ML.Runtime.Learners;
1515
using Microsoft.ML.Runtime.Numeric;
1616
using Microsoft.ML.Runtime.Training;
17+
using System;
1718

1819
[assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Arguments),
1920
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
@@ -59,9 +60,9 @@ public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
5960

6061
_outputColumns = new[]
6162
{
62-
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
63-
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
64-
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
63+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
64+
new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))),
65+
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
6566
};
6667
}
6768

@@ -79,6 +80,20 @@ protected override void CheckLabel(RoleMappedData data)
7980
data.CheckBinaryLabel();
8081
}
8182

83+
protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
84+
{
85+
Contracts.AssertValue(labelCol);
86+
87+
Action error =
88+
() => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "BL, R8, R4 or a Key", labelCol.GetTypeString());
89+
90+
if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
91+
error();
92+
93+
if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8 && !labelCol.ItemType.IsBool)
94+
error();
95+
}
96+
8297
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
8398
{
8499
return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);

src/Microsoft.ML.StandardLearners/Standard/Online/OnlineGradientDescent.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public OnlineGradientDescentTrainer(IHostEnvironment env, Arguments args)
5858

5959
_outputColumns = new[]
6060
{
61-
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false)
61+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
6262
};
6363
}
6464

src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,12 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
6161
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
6262
Contracts.Assert(success);
6363

64-
var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues));
64+
var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
65+
.Concat(MetadataUtils.GetTrainerOutputMetadata()));
6566
return new[]
6667
{
67-
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false),
68-
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, labelCol.ItemType, labelCol.IsKey, metadata)
68+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
69+
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, metadata)
6970
};
7071
}
7172

src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featur
6363
_args = args;
6464
_outputColumns = new[]
6565
{
66-
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false)
66+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
6767
};
6868
}
6969

test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs

+5-4
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,16 @@ protected void TestEstimatorCore(IEstimator<ITransformer> estimator,
153153
CheckSameSchemaShape(outSchemaShape, scoredTrainSchemaShape);
154154
}
155155

156-
private void CheckSameSchemaShape(SchemaShape first, SchemaShape second)
156+
private void CheckSameSchemaShape(SchemaShape promised, SchemaShape delivered)
157157
{
158-
Assert.True(first.Columns.Length == second.Columns.Length);
159-
var sortedCols1 = first.Columns.OrderBy(x => x.Name);
160-
var sortedCols2 = second.Columns.OrderBy(x => x.Name);
158+
Assert.True(promised.Columns.Length == delivered.Columns.Length);
159+
var sortedCols1 = promised.Columns.OrderBy(x => x.Name);
160+
var sortedCols2 = delivered.Columns.OrderBy(x => x.Name);
161161

162162
foreach (var (x, y) in sortedCols1.Zip(sortedCols2, (x, y) => (x, y)))
163163
{
164164
Assert.Equal(x.Name, y.Name);
165+
// We want the 'promised' metadata to be a superset of the 'delivered' metadata.
165166
Assert.True(y.IsCompatibleWith(x), $"Mismatch on {x.Name}");
166167
}
167168
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.Learners;
8+
using Microsoft.ML.Runtime.RunTests;
9+
using Xunit;
10+
using Xunit.Abstractions;
11+
12+
namespace Microsoft.ML.Tests.Transformers
13+
{
14+
public sealed class OnlineLinearTests : TestDataPipeBase
15+
{
16+
public OnlineLinearTests(ITestOutputHelper helper) : base(helper)
17+
{
18+
}
19+
20+
[Fact(Skip = "AP is now uncalibrated but advertises as calibrated")]
21+
public void OnlineLinearWorkout()
22+
{
23+
var dataPath = GetDataPath("breast-cancer.txt");
24+
25+
var data = TextLoader.CreateReader(Env, ctx => (Label: ctx.LoadFloat(0), Features: ctx.LoadFloat(1, 10)))
26+
.Read(new MultiFileSource(dataPath));
27+
28+
var pipe = data.MakeNewEstimator()
29+
.Append(r => (r.Label, Features: r.Features.Normalize()));
30+
31+
var trainData = pipe.Fit(data).Transform(data).AsDynamic;
32+
33+
IEstimator<ITransformer> est = new OnlineGradientDescentTrainer(Env, new OnlineGradientDescentTrainer.Arguments());
34+
TestEstimatorCore(est, trainData);
35+
36+
est = new AveragedPerceptronTrainer(Env, new AveragedPerceptronTrainer.Arguments());
37+
TestEstimatorCore(est, trainData);
38+
39+
Done();
40+
41+
}
42+
}
43+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.Learners;
8+
using Microsoft.ML.Runtime.RunTests;
9+
using Xunit;
10+
using Xunit.Abstractions;
11+
12+
namespace Microsoft.ML.Tests.Transformers
13+
{
14+
public sealed class SdcaTests : TestDataPipeBase
15+
{
16+
public SdcaTests(ITestOutputHelper helper) : base(helper)
17+
{
18+
}
19+
20+
[Fact]
21+
public void SdcaWorkout()
22+
{
23+
var dataPath = GetDataPath("breast-cancer.txt");
24+
25+
var data = TextLoader.CreateReader(Env, ctx => (Label: ctx.LoadFloat(0), Features: ctx.LoadFloat(1, 10)))
26+
.Read(new MultiFileSource(dataPath));
27+
28+
IEstimator<ITransformer> est = new LinearClassificationTrainer(Env, new LinearClassificationTrainer.Arguments { ConvergenceTolerance = 1e-2f }, "Features", "Label");
29+
TestEstimatorCore(est, data.AsDynamic);
30+
31+
est = new SdcaRegressionTrainer(Env, new SdcaRegressionTrainer.Arguments { ConvergenceTolerance = 1e-2f }, "Features", "Label");
32+
TestEstimatorCore(est, data.AsDynamic);
33+
34+
est = new SdcaMultiClassTrainer(Env, new SdcaMultiClassTrainer.Arguments { ConvergenceTolerance = 1e-2f }, "Features", "Label");
35+
TestEstimatorCore(est, data.AsDynamic);
36+
37+
Done();
38+
}
39+
}
40+
}

0 commit comments

Comments
 (0)