// min_entropy_loss_op.cc
#include "min_entropy_loss_op.h"
namespace caffe2 {
namespace {}
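// Forward pass. For every example n and every class c whose mask entry in L
// is active (Ldata[c] >= 0.5), the op accumulates the entropy term
//   -X[n][c] * log(X[n][c])
// and finally averages by the number of contributing (n, c) pairs.
// Probabilities are clamped from below at kLOG_THRESHOLD() before taking the
// log so that near-zero entries cannot produce -inf.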
template <>
bool MinEntropyLossOp<float, CPUContext>::RunOnDevice() {
const auto& X = Input(0);
const auto& L = Input(1);
CAFFE_ENFORCE_EQ(X.dim(), 2);
CAFFE_ENFORCE_EQ(L.dim(), 2);
CAFFE_ENFORCE_EQ(X.dim32(1), L.dim32(1));
const int N = X.dim32(0);
const int C = X.dim32(1);
CAFFE_ENFORCE_EQ(L.dim32(0), 1);
auto* Y = Output(0);
Y->Resize(vector<int64_t>());
math::Set<float, CPUContext>(Y->numel(), 0.f, Y->mutable_data<float>(),
&context_);
const float* Xdata = X.data<float>();
const float* Ldata = L.data<float>();
auto* Ydata = Y->mutable_data<float>();
float loss = 0;
int norm = 0;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) {
if (Ldata[c] < 0.5) {
continue;
}
float prob = std::max(Xdata[n * C + c], kLOG_THRESHOLD());
loss -= prob * std::log(prob);
norm += 1;
}
}
CAFFE_ENFORCE_GT(norm, 0, "No active classes in the label mask L");
Ydata[0] = loss / norm;
return true;
}
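// Backward pass. Differentiating -p * log(p) with respect to p gives
// -(log(p) + 1), so every active entry receives
//   dX[n][c] = dY / norm * (-1 - log(X[n][c])),
// with the same lower clamp on the probability and an upper clamp at
// kDIFF_THRESHOLD() to keep the gradient bounded as the probability
// approaches zero.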
template <>
bool MinEntropyLossGradientOp<float, CPUContext>::RunOnDevice() {
const auto& X = Input(0);
const auto& L = Input(1);
const auto& dY = Input(2);
CAFFE_ENFORCE_EQ(X.dim(), 2);
CAFFE_ENFORCE_EQ(L.dim(), 2);
CAFFE_ENFORCE_EQ(X.dim32(1), L.dim32(1));
CAFFE_ENFORCE_EQ(dY.numel(), 1);
const int N = X.dim32(0);
const int C = X.dim32(1);
CAFFE_ENFORCE_EQ(L.dim32(0), 1);
auto* dX = Output(0);
dX->ResizeLike(X);
math::Set<float, CPUContext>(dX->numel(), 0.f, dX->mutable_data<float>(),
&context_);
const float* Xdata = X.data<float>();
const float* Ldata = L.data<float>();
const float* dYdata = dY.data<float>();
float* dXdata = dX->mutable_data<float>();
int norm = 0;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) {
if (Ldata[c] < 0.5) {
continue;
}
norm += 1;
}
}
CAFFE_ENFORCE_GT(norm, 0, "No active classes in the label mask L");
const float scale = dYdata[0] / norm;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) {
if (Ldata[c] < 0.5) {
continue;
}
float prob = std::max(Xdata[n * C + c], kLOG_THRESHOLD());
dXdata[n * C + c] =
std::min(scale * (-1.f - std::log(prob)), kDIFF_THRESHOLD());
}
}
return true;
}
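// Worked example (illustrative): with X = [[0.5, 0.5]] and L = [[1, 1]],
// both entries are active, so norm = 2 and the forward pass yields
//   Y = -(0.5 * log(0.5) + 0.5 * log(0.5)) / 2 = log(2) / 2 ~= 0.3466.
// With dY = 1, each entry of dX is (-1 - log(0.5)) / 2 ~= -0.1534
// (assuming kDIFF_THRESHOLD() is large enough not to clip it).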
REGISTER_CPU_OPERATOR(MinEntropyLoss, MinEntropyLossOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(MinEntropyLossGradient,
MinEntropyLossGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(MinEntropyLoss)
.NumInputs(2, 3)
.NumOutputs(1)
.IdenticalTypeAndShapeOfInputDim(0, 0)
.SetDoc(R"DOC(
)DOC")
.Input(0, "X", "Input blob of size N x C")
.Input(1, "L", "Input Blob of size B x C")
.Output(0, "Y", "Output blob after the cross entropy computation");
OPERATOR_SCHEMA(MinEntropyLossGradient).NumInputs(3).NumOutputs(1);
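// The gradient maker below wires the backward op to the forward pass: it
// feeds X (I(0)), L (I(1)), and the upstream gradient of Y (GO(0)) into
// MinEntropyLossGradient and declares GI(0) as the gradient w.r.t. X.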
class GetMinEntropyLossGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef("MinEntropyLossGradient", "",
vector<string>{I(0), I(1), GO(0)},
vector<string>{GI(0)});
}
};
REGISTER_GRADIENT(MinEntropyLoss, GetMinEntropyLossGradient);
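// Usage sketch from the C++ side (hypothetical; the blob names and the
// surrounding Workspace setup are illustrative, not part of this file):
//   OperatorDef def = CreateOperatorDef(
//       "MinEntropyLoss", "", std::vector<string>{"X", "L"},
//       std::vector<string>{"Y"});
//   std::unique_ptr<OperatorBase> op = CreateOperator(def, &workspace);
//   op->Run();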
} // namespace caffe2