1 #include "caffe2/operators/jsd_op.h" 7 static constexpr
// kLOG_THRESHOLD(): small probability cutoff shared by logit() (used as the
// clamp bounds [kLOG_THRESHOLD(), 1 - kLOG_THRESHOLD()]) and entropy() (used
// as the near-0 / near-1 guard).
// NOTE(review): the body (i.e. the actual threshold value) is not visible in
// this view. If the value is at or below roughly 3e-8, `1 - kLOG_THRESHOLD()`
// rounds to exactly 1.0f in float arithmetic, so the upper clamp in logit()
// no longer keeps its argument below 1 — confirm the value.
float kLOG_THRESHOLD() {
11 inline float logit(
float p) {
14 float x = std::min(std::max(p, kLOG_THRESHOLD()), 1 - kLOG_THRESHOLD());
15 return -log(1. / x - 1.);
// entropy(p): has the form of the binary (Bernoulli) entropy in nats,
// -p*log(p) - q*log(q). Probabilities within kLOG_THRESHOLD() of 0 or 1 take
// the guarded branch instead of the log-based formula.
// NOTE(review): the guarded branch's return value and the declaration of `q`
// (presumably q = 1 - p) are on lines not visible in this view — confirm.
18 inline float entropy(
float p) {
19 if (p < kLOG_THRESHOLD() || 1 - p < kLOG_THRESHOLD()) {
23 return -p * log(p) - q * log(q);
// Forward pass: element-wise Jensen-Shannon divergence between the Bernoulli
// distribution with parameter x_data[i] (model) and the one with parameter
// t_data[i] (empirical target):
//   JSD = H(m) - (H(p_mdl) + H(p_emp)) / 2,  with m = (p_mdl + p_emp) / 2,
// where H is entropy().
// NOTE(review): the declarations of X, T, L and N, any output resize, and
// (presumably) the store of `jsd` into l_data[i] and the return are on lines
// not visible in this view — confirm.
29 bool BernoulliJSDOp<float, CPUContext>::RunOnDevice() {
34 CAFFE_ENFORCE_EQ(T.size(), N);
36 auto* x_data = X.data<
float>();
37 auto* t_data = T.data<
float>();
38 auto* l_data = L->mutable_data<
float>();
39 for (
int i = 0; i < N; i++) {
40 auto p_mdl = x_data[i];
41 auto p_emp = t_data[i];
// Midpoint (mixture) probability; computed in double because of the `2.`
// literal and narrowed back to float when passed to entropy().
42 auto p_avg = (p_mdl + p_emp) / 2.;
43 auto jsd = entropy(p_avg) - (entropy(p_mdl) + entropy(p_emp)) / 2.;
// Backward pass: gradient of the JSD loss with respect to the model
// probabilities (x_data) only; only one gradient output (gi) is written.
// For H(p) = -p*log(p) - (1-p)*log(1-p), dH/dp = -logit(p). With
// m = (p_mdl + p_emp) / 2 this gives, away from the guarded endpoints,
//   dJSD/dp_mdl = (logit(p_mdl) - logit(m)) / 2,
// which is scaled by the incoming gradient go_data[i] (chain rule).
// NOTE(review): the declarations of go, X, T, gi and N, any output resize and
// the return statement are on lines not visible in this view — confirm.
50 bool BernoulliJSDGradientOp<float, CPUContext>::RunOnDevice() {
57 auto* go_data = go.data<
float>();
58 auto* x_data = X.data<
float>();
59 auto* t_data = T.data<
float>();
60 auto* gi_data = gi->mutable_data<
float>();
61 for (
int i = 0; i < N; i++) {
62 auto p_mdl = x_data[i];
63 auto p_emp = t_data[i];
64 auto p_avg = (p_mdl + p_emp) / 2.;
// Local derivative of the per-element JSD w.r.t. p_mdl (see header comment).
65 auto g_jsd = (logit(p_mdl) - logit(p_avg)) / 2.;
66 gi_data[i] = go_data[i] * g_jsd;
70 REGISTER_CPU_OPERATOR(BernoulliJSD, BernoulliJSDOp<float, CPUContext>);
71 REGISTER_CPU_OPERATOR(
73 BernoulliJSDGradientOp<float, CPUContext>);
74 OPERATOR_SCHEMA(BernoulliJSD)
78 Computes the Jensen-Shannon divergence (JSD) between two Bernoulli distributions 79 where each is parametrized by a single probability. 81 .Input(0, "X",
"array of probabilities for prediction")
82 .Input(0,
"T",
"array of probabilities for target")
83 .Output(0,
"L",
"array of JSD losses");
84 OPERATOR_SCHEMA(BernoulliJSDGradient).NumInputs(3).NumOutputs(1);
// Gradient maker (its class header is not visible in this view; presumably it
// is the one registered for BernoulliJSD).
87 using GradientMakerBase::GradientMakerBase;
// Emits a single BernoulliJSDGradient op that reads the output gradient GO(0)
// and both forward inputs I(0) (X) and I(1) (T), and writes only GI(0):
// no gradient is produced for the target T.
// NOTE(review): the `return SingleGradientDef(` line and one argument between
// the op type and the input list are not visible here — confirm.
88 vector<OperatorDef> GetGradientDefs()
override {
90 "BernoulliJSDGradient",
92 vector<string>{GO(0), I(0), I(1)},
93 vector<string>{GI(0)});
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...