2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
- using Float = System . Single ;
6
-
5
+ using Microsoft . ML . Core . Data ;
7
6
using Microsoft . ML . Runtime . CommandLine ;
8
7
using Microsoft . ML . Runtime . Data ;
9
8
using Microsoft . ML . Runtime . Data . Conversion ;
10
9
using Microsoft . ML . Runtime . EntryPoints ;
11
10
using Microsoft . ML . Runtime . Internal . Calibration ;
12
11
using Microsoft . ML . Runtime . Internal . Internallearn ;
13
12
using Microsoft . ML . Runtime . Training ;
13
+ using System . Collections . Generic ;
14
+ using System . Linq ;
14
15
15
16
namespace Microsoft . ML . Runtime . Learners
16
17
{
17
- using TScalarTrainer = ITrainer < IPredictorProducing < Float > > ;
18
+ using TScalarTrainer = ITrainerEstimator < IPredictionTransformer < IPredictorProducing < float > > , IPredictorProducing < float > > ;
18
19
19
- public abstract class MetaMulticlassTrainer < TPred , TArgs > : TrainerBase < TPred >
20
- where TPred : IPredictor
21
- where TArgs : MetaMulticlassTrainer < TPred , TArgs > . ArgumentsBase
20
+ public abstract class MetaMulticlassTrainer < TTransformer , TModel > : ITrainerEstimator < TTransformer , TModel > , ITrainer < TModel >
21
+ where TTransformer : IPredictionTransformer < TModel >
22
+ where TModel : IPredictor
22
23
{
23
24
public abstract class ArgumentsBase
24
25
{
25
- [ Argument ( ArgumentType . Multiple , HelpText = "Base predictor" , ShortName = "p" , SortOrder = 1 , SignatureType = typeof ( SignatureBinaryClassifierTrainer ) ) ]
26
+ [ Argument ( ArgumentType . Multiple , HelpText = "Base predictor" , ShortName = "p" , SortOrder = 4 , SignatureType = typeof ( SignatureBinaryClassifierTrainer ) ) ]
26
27
[ TGUI ( Label = "Predictor Type" , Description = "Type of underlying binary predictor" ) ]
27
28
public IComponentFactory < TScalarTrainer > PredictorType ;
28
29
29
- [ Argument ( ArgumentType . Multiple , HelpText = "Output calibrator" , ShortName = "cali" , NullName = "<None>" , SignatureType = typeof ( SignatureCalibrator ) ) ]
30
+ [ Argument ( ArgumentType . Multiple , HelpText = "Output calibrator" , ShortName = "cali" , SortOrder = 150 , NullName = "<None>" , SignatureType = typeof ( SignatureCalibrator ) ) ]
30
31
public IComponentFactory < ICalibratorTrainer > Calibrator = new PlattCalibratorTrainerFactory ( ) ;
31
32
32
- [ Argument ( ArgumentType . LastOccurenceWins , HelpText = "Number of instances to train the calibrator" , ShortName = "numcali" ) ]
33
+ [ Argument ( ArgumentType . LastOccurenceWins , HelpText = "Number of instances to train the calibrator" , SortOrder = 150 , ShortName = "numcali" ) ]
33
34
public int MaxCalibrationExamples = 1000000000 ;
34
35
35
- [ Argument ( ArgumentType . Multiple , HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing" , ShortName = "missNeg" ) ]
36
+ [ Argument ( ArgumentType . Multiple , HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing" , SortOrder = 150 , ShortName = "missNeg" ) ]
36
37
public bool ImputeMissingLabelsAsNegative ;
37
38
}
38
39
39
- protected readonly TArgs Args ;
40
+ /// <summary>
41
+ /// The label column that the trainer expects.
42
+ /// </summary>
43
+ public readonly SchemaShape . Column LabelColumn ;
44
+
45
+ protected readonly ArgumentsBase Args ;
46
+ protected readonly IHost Host ;
47
+ protected readonly ICalibratorTrainer Calibrator ;
48
+
40
49
private TScalarTrainer _trainer ;
41
50
42
- public sealed override PredictionKind PredictionKind => PredictionKind . MultiClassClassification ;
43
- public override TrainerInfo Info { get ; }
51
+ public PredictionKind PredictionKind => PredictionKind . MultiClassClassification ;
52
+
53
+ protected SchemaShape . Column [ ] OutputColumns ;
44
54
45
- internal MetaMulticlassTrainer ( IHostEnvironment env , TArgs args , string name )
46
- : base ( env , name )
55
+ public TrainerInfo Info { get ; }
56
+
57
+ public TScalarTrainer PredictorType ;
58
+
59
+ /// <summary>
60
+ /// Initializes the <see cref="MetaMulticlassTrainer{TTransformer, TModel}"/> from the Arguments class.
61
+ /// </summary>
62
+ /// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
63
+ /// <param name="args">The legacy arguments <see cref="ArgumentsBase"/>class.</param>
64
+ /// <param name="name">The component name.</param>
65
+ /// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param>
66
+ /// <param name="singleEstimator">The binary estimator.</param>
67
+ /// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorCalibratorTrainer"/></param>
68
+ internal MetaMulticlassTrainer ( IHostEnvironment env , ArgumentsBase args , string name , string labelColumn = null ,
69
+ TScalarTrainer singleEstimator = null , ICalibratorTrainer calibrator = null )
47
70
{
71
+ Host = Contracts . CheckRef ( env , nameof ( env ) ) . Register ( name ) ;
48
72
Host . CheckValue ( args , nameof ( args ) ) ;
49
73
Args = args ;
74
+
75
+ if ( labelColumn != null )
76
+ LabelColumn = new SchemaShape . Column ( labelColumn , SchemaShape . Column . VectorKind . Scalar , NumberType . U4 , true ) ;
77
+
50
78
// Create the first trainer so errors in the args surface early.
51
- _trainer = CreateTrainer ( ) ;
79
+ _trainer = singleEstimator ?? CreateTrainer ( ) ;
80
+
81
+ Calibrator = calibrator ?? new PlattCalibratorTrainer ( env ) ;
82
+
83
+ if ( args . Calibrator != null )
84
+ Calibrator = args . Calibrator . CreateComponent ( Host ) ;
85
+
52
86
// Regarding caching, no matter what the internal predictor, we're performing many passes
53
87
// simply by virtue of this being a meta-trainer, so we will still cache.
54
88
Info = new TrainerInfo ( normalization : _trainer . Info . NeedNormalization ) ;
@@ -61,29 +95,28 @@ private TScalarTrainer CreateTrainer()
61
95
new LinearSvm ( Host , new LinearSvm . Arguments ( ) ) ;
62
96
}
63
97
64
- protected IDataView MapLabelsCore < T > ( ColumnType type , RefPredicate < T > equalsTarget , RoleMappedData data , string dstName )
98
+ protected IDataView MapLabelsCore < T > ( ColumnType type , RefPredicate < T > equalsTarget , RoleMappedData data )
65
99
{
66
100
Host . AssertValue ( type ) ;
67
101
Host . Assert ( type . RawType == typeof ( T ) ) ;
68
102
Host . AssertValue ( equalsTarget ) ;
69
103
Host . AssertValue ( data ) ;
70
104
Host . AssertValue ( data . Schema . Label ) ;
71
- Host . AssertNonWhiteSpace ( dstName ) ;
72
105
73
106
var lab = data . Schema . Label ;
74
107
75
108
RefPredicate < T > isMissing ;
76
109
if ( ! Args . ImputeMissingLabelsAsNegative && Conversions . Instance . TryGetIsNAPredicate ( type , out isMissing ) )
77
110
{
78
111
return LambdaColumnMapper . Create ( Host , "Label mapper" , data . Data ,
79
- lab . Name , dstName , type , NumberType . Float ,
80
- ( ref T src , ref Float dst ) =>
81
- dst = equalsTarget ( ref src ) ? 1 : ( isMissing ( ref src ) ? Float . NaN : default ( Float ) ) ) ;
112
+ lab . Name , lab . Name , type , NumberType . Float ,
113
+ ( ref T src , ref float dst ) =>
114
+ dst = equalsTarget ( ref src ) ? 1 : ( isMissing ( ref src ) ? float . NaN : default ( float ) ) ) ;
82
115
}
83
116
return LambdaColumnMapper . Create ( Host , "Label mapper" , data . Data ,
84
- lab . Name , dstName , type , NumberType . Float ,
85
- ( ref T src , ref Float dst ) =>
86
- dst = equalsTarget ( ref src ) ? 1 : default ( Float ) ) ;
117
+ lab . Name , lab . Name , type , NumberType . Float ,
118
+ ( ref T src , ref float dst ) =>
119
+ dst = equalsTarget ( ref src ) ? 1 : default ( float ) ) ;
87
120
}
88
121
89
122
protected TScalarTrainer GetTrainer ( )
@@ -95,9 +128,14 @@ protected TScalarTrainer GetTrainer()
95
128
return train ;
96
129
}
97
130
98
- protected abstract TPred TrainCore ( IChannel ch , RoleMappedData data , int count ) ;
131
+ protected abstract TModel TrainCore ( IChannel ch , RoleMappedData data , int count ) ;
99
132
100
- public override TPred Train ( TrainContext context )
133
+ /// <summary>
134
+ /// The legacy train method.
135
+ /// </summary>
136
+ /// <param name="context">The trainig context for this learner.</param>
137
+ /// <returns>The trained model.</returns>
138
+ public TModel Train ( TrainContext context )
101
139
{
102
140
Host . CheckValue ( context , nameof ( context ) ) ;
103
141
var data = context . TrainingSet ;
@@ -116,5 +154,76 @@ public override TPred Train(TrainContext context)
116
154
return pred ;
117
155
}
118
156
}
157
+
158
+ /// <summary>
159
+ /// Gets the output columns.
160
+ /// </summary>
161
+ /// <param name="inputSchema">The input schema. </param>
162
+ /// <returns>The output <see cref="SchemaShape"/></returns>
163
+ public SchemaShape GetOutputSchema ( SchemaShape inputSchema )
164
+ {
165
+ Host . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
166
+
167
+ if ( LabelColumn != null )
168
+ {
169
+ if ( ! inputSchema . TryFindColumn ( LabelColumn . Name , out var labelCol ) )
170
+ throw Host . ExceptSchemaMismatch ( nameof ( labelCol ) , DefaultColumnNames . PredictedLabel , DefaultColumnNames . PredictedLabel ) ;
171
+
172
+ if ( ! LabelColumn . IsCompatibleWith ( labelCol ) )
173
+ throw Host . Except ( $ "Label column '{ LabelColumn . Name } ' is not compatible") ;
174
+ }
175
+
176
+ var outColumns = inputSchema . Columns . ToDictionary ( x => x . Name ) ;
177
+ foreach ( var col in GetOutputColumnsCore ( inputSchema ) )
178
+ outColumns [ col . Name ] = col ;
179
+
180
+ return new SchemaShape ( outColumns . Values ) ;
181
+ }
182
+
183
+ private SchemaShape . Column [ ] GetOutputColumnsCore ( SchemaShape inputSchema )
184
+ {
185
+ if ( LabelColumn != null )
186
+ {
187
+ bool success = inputSchema . TryFindColumn ( LabelColumn . Name , out var labelCol ) ;
188
+ Contracts . Assert ( success ) ;
189
+
190
+ var metadata = new SchemaShape ( labelCol . Metadata . Columns . Where ( x => x . Name == MetadataUtils . Kinds . KeyValues )
191
+ . Concat ( MetadataForScoreColumn ( ) ) ) ;
192
+ return new [ ]
193
+ {
194
+ new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Vector , NumberType . R4 , false , new SchemaShape ( MetadataForScoreColumn ( ) ) ) ,
195
+ new SchemaShape . Column ( DefaultColumnNames . PredictedLabel , SchemaShape . Column . VectorKind . Scalar , NumberType . U4 , true , metadata )
196
+ } ;
197
+ }
198
+ else
199
+ return new [ ]
200
+ {
201
+ new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Vector , NumberType . R4 , false , new SchemaShape ( MetadataForScoreColumn ( ) ) ) ,
202
+ new SchemaShape . Column ( DefaultColumnNames . PredictedLabel , SchemaShape . Column . VectorKind . Scalar , NumberType . U4 , true , new SchemaShape ( MetadataForScoreColumn ( ) ) )
203
+ } ;
204
+ }
205
+
206
+ /// <summary>
207
+ /// Normal metadata that we produce for score columns.
208
+ /// </summary>
209
+ private static IEnumerable < SchemaShape . Column > MetadataForScoreColumn ( )
210
+ {
211
+ var cols = new List < SchemaShape . Column > ( ) ;
212
+ cols . Add ( new SchemaShape . Column ( MetadataUtils . Kinds . ScoreColumnSetId , SchemaShape . Column . VectorKind . Scalar , NumberType . U4 , true ) ) ;
213
+ cols . Add ( new SchemaShape . Column ( MetadataUtils . Kinds . ScoreColumnKind , SchemaShape . Column . VectorKind . Scalar , TextType . Instance , false ) ) ;
214
+ cols . Add ( new SchemaShape . Column ( MetadataUtils . Kinds . SlotNames , SchemaShape . Column . VectorKind . Vector , TextType . Instance , false ) ) ;
215
+ cols . Add ( new SchemaShape . Column ( MetadataUtils . Kinds . ScoreValueKind , SchemaShape . Column . VectorKind . Scalar , TextType . Instance , false ) ) ;
216
+
217
+ return cols ;
218
+ }
219
+
220
+ IPredictor ITrainer . Train ( TrainContext context ) => Train ( context ) ;
221
+
222
+ /// <summary>
223
+ /// Fits the data to the trainer.
224
+ /// </summary>
225
+ /// <param name="input">The input data to fit to.</param>
226
+ /// <returns>The transformer.</returns>
227
+ public abstract TTransformer Fit ( IDataView input ) ;
119
228
}
120
- }
229
+ }
0 commit comments