14
14
using Microsoft . ML . Runtime . Learners ;
15
15
using Microsoft . ML . Runtime . Numeric ;
16
16
using Microsoft . ML . Runtime . Training ;
17
+ using System ;
17
18
18
19
[ assembly: LoadableClass ( AveragedPerceptronTrainer . Summary , typeof ( AveragedPerceptronTrainer ) , typeof ( AveragedPerceptronTrainer . Arguments ) ,
19
20
new [ ] { typeof ( SignatureBinaryClassifierTrainer ) , typeof ( SignatureTrainer ) , typeof ( SignatureFeatureScorerTrainer ) } ,
@@ -59,9 +60,9 @@ public AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
59
60
60
61
_outputColumns = new [ ]
61
62
{
62
- new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false ) ,
63
- new SchemaShape . Column ( DefaultColumnNames . Probability , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false ) ,
64
- new SchemaShape . Column ( DefaultColumnNames . PredictedLabel , SchemaShape . Column . VectorKind . Scalar , BoolType . Instance , false )
63
+ new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( ) ) ) ,
64
+ new SchemaShape . Column ( DefaultColumnNames . Probability , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( true ) ) ) ,
65
+ new SchemaShape . Column ( DefaultColumnNames . PredictedLabel , SchemaShape . Column . VectorKind . Scalar , BoolType . Instance , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( ) ) )
65
66
} ;
66
67
}
67
68
@@ -79,6 +80,20 @@ protected override void CheckLabel(RoleMappedData data)
79
80
data . CheckBinaryLabel ( ) ;
80
81
}
81
82
83
+ protected override void CheckLabelCompatible ( SchemaShape . Column labelCol )
84
+ {
85
+ Contracts . AssertValue ( labelCol ) ;
86
+
87
+ Action error =
88
+ ( ) => throw Host . ExceptSchemaMismatch ( nameof ( labelCol ) , RoleMappedSchema . ColumnRole . Label . Value , labelCol . Name , "BL, R8, R4 or a Key" , labelCol . GetTypeString ( ) ) ;
89
+
90
+ if ( labelCol . Kind != SchemaShape . Column . VectorKind . Scalar )
91
+ error ( ) ;
92
+
93
+ if ( ! labelCol . IsKey && labelCol . ItemType != NumberType . R4 && labelCol . ItemType != NumberType . R8 && ! labelCol . ItemType . IsBool )
94
+ error ( ) ;
95
+ }
96
+
82
97
private static SchemaShape . Column MakeLabelColumn ( string labelColumn )
83
98
{
84
99
return new SchemaShape . Column ( labelColumn , SchemaShape . Column . VectorKind . Scalar , BoolType . Instance , false ) ;
0 commit comments