Skip to content

Commit bebf523

Browse files
authored
Ova and Pkpd as estimators (#865)
1 parent 66de9a4 commit bebf523

File tree

11 files changed

+538
-192
lines changed

11 files changed

+538
-192
lines changed

src/Microsoft.ML.Data/Prediction/Calibrator.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1122,8 +1122,8 @@ public sealed class PlattCalibratorTrainer : CalibratorTrainerBase
11221122
private Double _paramA;
11231123
private Double _paramB;
11241124

1125-
public const string UserName = "Sigmoid Calibration";
1126-
public const string LoadName = "PlattCalibration";
1125+
internal const string UserName = "Sigmoid Calibration";
1126+
internal const string LoadName = "PlattCalibration";
11271127
internal const string Summary = "This model was introduced by Platt in the paper Probabilistic Outputs for Support Vector Machines "
11281128
+ "and Comparisons to Regularized Likelihood Methods";
11291129

src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public Arguments()
5757
env => new Ova(env, new Ova.Arguments()
5858
{
5959
PredictorType = ComponentFactoryUtils.CreateFromFunction(
60-
e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
60+
e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments()))
6161
}));
6262
}
6363
}

src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs

+136-27
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,87 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Float = System.Single;
6-
5+
using Microsoft.ML.Core.Data;
76
using Microsoft.ML.Runtime.CommandLine;
87
using Microsoft.ML.Runtime.Data;
98
using Microsoft.ML.Runtime.Data.Conversion;
109
using Microsoft.ML.Runtime.EntryPoints;
1110
using Microsoft.ML.Runtime.Internal.Calibration;
1211
using Microsoft.ML.Runtime.Internal.Internallearn;
1312
using Microsoft.ML.Runtime.Training;
13+
using System.Collections.Generic;
14+
using System.Linq;
1415

1516
namespace Microsoft.ML.Runtime.Learners
1617
{
17-
using TScalarTrainer = ITrainer<IPredictorProducing<Float>>;
18+
using TScalarTrainer = ITrainerEstimator<IPredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
1819

19-
public abstract class MetaMulticlassTrainer<TPred, TArgs> : TrainerBase<TPred>
20-
where TPred : IPredictor
21-
where TArgs : MetaMulticlassTrainer<TPred, TArgs>.ArgumentsBase
20+
public abstract class MetaMulticlassTrainer<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>, ITrainer<TModel>
21+
where TTransformer : IPredictionTransformer<TModel>
22+
where TModel : IPredictor
2223
{
2324
public abstract class ArgumentsBase
2425
{
25-
[Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 1, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
26+
[Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 4, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
2627
[TGUI(Label = "Predictor Type", Description = "Type of underlying binary predictor")]
2728
public IComponentFactory<TScalarTrainer> PredictorType;
2829

29-
[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "<None>", SignatureType = typeof(SignatureCalibrator))]
30+
[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", SortOrder = 150, NullName = "<None>", SignatureType = typeof(SignatureCalibrator))]
3031
public IComponentFactory<ICalibratorTrainer> Calibrator = new PlattCalibratorTrainerFactory();
3132

32-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")]
33+
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", SortOrder = 150, ShortName = "numcali")]
3334
public int MaxCalibrationExamples = 1000000000;
3435

35-
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", ShortName = "missNeg")]
36+
[Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", SortOrder = 150, ShortName = "missNeg")]
3637
public bool ImputeMissingLabelsAsNegative;
3738
}
3839

39-
protected readonly TArgs Args;
40+
/// <summary>
41+
/// The label column that the trainer expects.
42+
/// </summary>
43+
public readonly SchemaShape.Column LabelColumn;
44+
45+
protected readonly ArgumentsBase Args;
46+
protected readonly IHost Host;
47+
protected readonly ICalibratorTrainer Calibrator;
48+
4049
private TScalarTrainer _trainer;
4150

42-
public sealed override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
43-
public override TrainerInfo Info { get; }
51+
public PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
52+
53+
protected SchemaShape.Column[] OutputColumns;
4454

45-
internal MetaMulticlassTrainer(IHostEnvironment env, TArgs args, string name)
46-
: base(env, name)
55+
public TrainerInfo Info { get; }
56+
57+
public TScalarTrainer PredictorType;
58+
59+
/// <summary>
60+
/// Initializes the <see cref="MetaMulticlassTrainer{TTransformer, TModel}"/> from the Arguments class.
61+
/// </summary>
62+
/// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
63+
/// <param name="args">The legacy arguments <see cref="ArgumentsBase"/>class.</param>
64+
/// <param name="name">The component name.</param>
65+
/// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param>
66+
/// <param name="singleEstimator">The binary estimator.</param>
67+
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorCalibratorTrainer"/></param>
68+
internal MetaMulticlassTrainer(IHostEnvironment env, ArgumentsBase args, string name, string labelColumn = null,
69+
TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null)
4770
{
71+
Host = Contracts.CheckRef(env, nameof(env)).Register(name);
4872
Host.CheckValue(args, nameof(args));
4973
Args = args;
74+
75+
if (labelColumn != null)
76+
LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
77+
5078
// Create the first trainer so errors in the args surface early.
51-
_trainer = CreateTrainer();
79+
_trainer = singleEstimator ?? CreateTrainer();
80+
81+
Calibrator = calibrator ?? new PlattCalibratorTrainer(env);
82+
83+
if (args.Calibrator != null)
84+
Calibrator = args.Calibrator.CreateComponent(Host);
85+
5286
// Regarding caching, no matter what the internal predictor, we're performing many passes
5387
// simply by virtue of this being a meta-trainer, so we will still cache.
5488
Info = new TrainerInfo(normalization: _trainer.Info.NeedNormalization);
@@ -61,29 +95,28 @@ private TScalarTrainer CreateTrainer()
6195
new LinearSvm(Host, new LinearSvm.Arguments());
6296
}
6397

64-
protected IDataView MapLabelsCore<T>(ColumnType type, RefPredicate<T> equalsTarget, RoleMappedData data, string dstName)
98+
protected IDataView MapLabelsCore<T>(ColumnType type, RefPredicate<T> equalsTarget, RoleMappedData data)
6599
{
66100
Host.AssertValue(type);
67101
Host.Assert(type.RawType == typeof(T));
68102
Host.AssertValue(equalsTarget);
69103
Host.AssertValue(data);
70104
Host.AssertValue(data.Schema.Label);
71-
Host.AssertNonWhiteSpace(dstName);
72105

73106
var lab = data.Schema.Label;
74107

75108
RefPredicate<T> isMissing;
76109
if (!Args.ImputeMissingLabelsAsNegative && Conversions.Instance.TryGetIsNAPredicate(type, out isMissing))
77110
{
78111
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
79-
lab.Name, dstName, type, NumberType.Float,
80-
(ref T src, ref Float dst) =>
81-
dst = equalsTarget(ref src) ? 1 : (isMissing(ref src) ? Float.NaN : default(Float)));
112+
lab.Name, lab.Name, type, NumberType.Float,
113+
(ref T src, ref float dst) =>
114+
dst = equalsTarget(ref src) ? 1 : (isMissing(ref src) ? float.NaN : default(float)));
82115
}
83116
return LambdaColumnMapper.Create(Host, "Label mapper", data.Data,
84-
lab.Name, dstName, type, NumberType.Float,
85-
(ref T src, ref Float dst) =>
86-
dst = equalsTarget(ref src) ? 1 : default(Float));
117+
lab.Name, lab.Name, type, NumberType.Float,
118+
(ref T src, ref float dst) =>
119+
dst = equalsTarget(ref src) ? 1 : default(float));
87120
}
88121

89122
protected TScalarTrainer GetTrainer()
@@ -95,9 +128,14 @@ protected TScalarTrainer GetTrainer()
95128
return train;
96129
}
97130

98-
protected abstract TPred TrainCore(IChannel ch, RoleMappedData data, int count);
131+
protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, int count);
99132

100-
public override TPred Train(TrainContext context)
133+
/// <summary>
134+
/// The legacy train method.
135+
/// </summary>
136+
/// <param name="context">The trainig context for this learner.</param>
137+
/// <returns>The trained model.</returns>
138+
public TModel Train(TrainContext context)
101139
{
102140
Host.CheckValue(context, nameof(context));
103141
var data = context.TrainingSet;
@@ -116,5 +154,76 @@ public override TPred Train(TrainContext context)
116154
return pred;
117155
}
118156
}
157+
158+
/// <summary>
159+
/// Gets the output columns.
160+
/// </summary>
161+
/// <param name="inputSchema">The input schema. </param>
162+
/// <returns>The output <see cref="SchemaShape"/></returns>
163+
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
164+
{
165+
Host.CheckValue(inputSchema, nameof(inputSchema));
166+
167+
if (LabelColumn != null)
168+
{
169+
if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
170+
throw Host.ExceptSchemaMismatch(nameof(labelCol), DefaultColumnNames.PredictedLabel, DefaultColumnNames.PredictedLabel);
171+
172+
if (!LabelColumn.IsCompatibleWith(labelCol))
173+
throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
174+
}
175+
176+
var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
177+
foreach (var col in GetOutputColumnsCore(inputSchema))
178+
outColumns[col.Name] = col;
179+
180+
return new SchemaShape(outColumns.Values);
181+
}
182+
183+
private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
184+
{
185+
if (LabelColumn != null)
186+
{
187+
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
188+
Contracts.Assert(success);
189+
190+
var metadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
191+
.Concat(MetadataForScoreColumn()));
192+
return new[]
193+
{
194+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(MetadataForScoreColumn())),
195+
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, metadata)
196+
};
197+
}
198+
else
199+
return new[]
200+
{
201+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(MetadataForScoreColumn())),
202+
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, new SchemaShape(MetadataForScoreColumn()))
203+
};
204+
}
205+
206+
/// <summary>
207+
/// Normal metadata that we produce for score columns.
208+
/// </summary>
209+
private static IEnumerable<SchemaShape.Column> MetadataForScoreColumn()
210+
{
211+
var cols = new List<SchemaShape.Column>();
212+
cols.Add(new SchemaShape.Column(MetadataUtils.Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true));
213+
cols.Add(new SchemaShape.Column(MetadataUtils.Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar, TextType.Instance, false));
214+
cols.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false));
215+
cols.Add(new SchemaShape.Column(MetadataUtils.Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar, TextType.Instance, false));
216+
217+
return cols;
218+
}
219+
220+
IPredictor ITrainer.Train(TrainContext context) => Train(context);
221+
222+
/// <summary>
223+
/// Fits the data to the trainer.
224+
/// </summary>
225+
/// <param name="input">The input data to fit to.</param>
226+
/// <returns>The transformer.</returns>
227+
public abstract TTransformer Fit(IDataView input);
119228
}
120-
}
229+
}

0 commit comments

Comments
 (0)