-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathkl_op.cc.bk
35 lines (29 loc) · 1.03 KB
/
kl_op.cc.bk
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include "kl_op.h"
namespace caffe2 {

// Schema for the KL operator. The kernel itself lives in kl_op.h (not visible
// here), so the documentation below only states what this file establishes:
// argument type/default, input/output arity, shape inference, and how the
// outputs feed into KLGradient.
OPERATOR_SCHEMA(KL)
    // ignore_value is read as a float with default 0.5 (see GetKLGradient).
    // NOTE(review): its exact semantics are defined by the kernel in kl_op.h —
    // confirm them and expand this description.
    .Arg("ignore_value", R"DOC(default is 0.5.)DOC")
    .NumInputs(2)
    .NumOutputs(2)
    // Output 0 ("divergence") gets shape [p.dims(0)] and p's data type, i.e.
    // one entry per example. Output 1 ("count") has no shape inference here.
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
Computes, for each example (each index along the first dimension of `p`), the
KL divergence between the matrices `p` and `q`, which must have the same shape.
The result `divergence` is a vector of length `p.dims(0)`. The second output,
`count`, is an auxiliary tensor that the gradient operator (KLGradient) takes
as an input. The float argument `ignore_value` defaults to 0.5 and is
forwarded to KLGradient. Gradients are propagated to `p` only.
)DOC")
    .Input(0, "p", "matrix for each example and class.")
    .Input(1, "q", "matrix, same shape as p.")
    .Output(0, "divergence", "Vector with the divergence for each example.")
    .Output(1, "count", "Auxiliary output consumed by KLGradient.");

// Inputs, in the order GetKLGradient wires them: gradient of `divergence`,
// p, q, and the forward `count` output. Output: gradient with respect to p.
OPERATOR_SCHEMA(KLGradient).NumInputs(4).NumOutputs(1);

// Builds the backward pass of KL as a single KLGradient op. It receives the
// incoming gradient of `divergence`, both forward inputs, and the forward
// `count` output, and produces a gradient for p only (none is produced for q).
struct GetKLGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    // Forward ignore_value explicitly, falling back to the same 0.5 default
    // documented on the KL schema so both ops agree when it is unset.
    // NOTE(review): GradientMakerBase usually also copies the forward op's
    // arguments onto gradient ops (CopyArguments()). If that holds in this
    // tree, ignore_value may appear twice on KLGradient — confirm.
    ArgumentHelper argsHelper(def_);
    const auto ignore_value =
        argsHelper.GetSingleArgument<float>("ignore_value", 0.5);
    return SingleGradientDef(
        "KLGradient",
        "",
        vector<string>{GO(0), I(0), I(1), O(1)},
        vector<string>{GI(0)},
        vector<Argument>{MakeArgument<float>("ignore_value", ignore_value)});
  }
};

REGISTER_GRADIENT(KL, GetKLGradient);

} // namespace caffe2