Caffe2 - C++ API
A deep learning, cross-platform ML framework
cross_entropy_op.h
1 #ifndef CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
2 #define CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class LabelCrossEntropyOp final : public Operator<Context> {
13  public:
14  USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyOp);
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  bool RunOnDevice() override;
17 
18  protected:
19  static constexpr T kLOG_THRESHOLD() {
20  return static_cast<T>(1e-20);
21  }
22  // Input: X, label
23  // Output: Y
24 };
25 
26 template <typename T, class Context>
27 class LabelCrossEntropyGradientOp final : public Operator<Context> {
28  public:
29  USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyGradientOp);
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  bool RunOnDevice() override;
32 
33  protected:
34  // Input: X, label, dY
35  // Ouptut: dX. There is no gradient with respect to the label.
36  static constexpr T kLOG_THRESHOLD() {
37  return static_cast<T>(1e-20);
38  }
39 };
40 
41 // Hacky: turns a vector of probabilities into a 2-column matrix with
42 // complimentary probabilities for binary classification
43 template <typename T, class Context>
44 class MakeTwoClassOp final : public Operator<Context> {
45  public:
46  USE_SIMPLE_CTOR_DTOR(MakeTwoClassOp);
47  USE_OPERATOR_CONTEXT_FUNCTIONS;
48  bool RunOnDevice() override;
49 
50  protected:
51  // Input: X
52  // Output: Y = vstack(1-X, X)
53 };
54 
55 template <typename T, class Context>
56 class MakeTwoClassGradientOp final : public Operator<Context> {
57  public:
58  USE_SIMPLE_CTOR_DTOR(MakeTwoClassGradientOp);
59  USE_OPERATOR_CONTEXT_FUNCTIONS;
60  bool RunOnDevice() override;
61 
62  protected:
63  // Input: dY
64  // Ouptut: dX
65 };
66 
67 template <typename T, class Context>
68 class SigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
69  public:
70  USE_SIMPLE_CTOR_DTOR(SigmoidCrossEntropyWithLogitsOp);
71  USE_OPERATOR_CONTEXT_FUNCTIONS;
72  bool RunOnDevice() override;
73 };
74 
75 template <typename T, class Context>
76 class SigmoidCrossEntropyWithLogitsGradientOp final : public Operator<Context> {
77  public:
78  USE_SIMPLE_CTOR_DTOR(SigmoidCrossEntropyWithLogitsGradientOp);
79  USE_OPERATOR_CONTEXT_FUNCTIONS;
80  bool RunOnDevice() override;
81 };
82 
83 template <typename T, class Context>
84 class WeightedSigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
85  public:
86  USE_SIMPLE_CTOR_DTOR(WeightedSigmoidCrossEntropyWithLogitsOp);
87  USE_OPERATOR_CONTEXT_FUNCTIONS;
88  bool RunOnDevice() override;
89 };
90 
91 template <typename T, class Context>
93  : public Operator<Context> {
94  public:
96  USE_OPERATOR_CONTEXT_FUNCTIONS;
97  bool RunOnDevice() override;
98 };
99 
100 template <typename T, class Context>
101 class CrossEntropyOp final : public Operator<Context> {
102  public:
103  USE_SIMPLE_CTOR_DTOR(CrossEntropyOp);
104  USE_OPERATOR_CONTEXT_FUNCTIONS;
105  bool RunOnDevice() override;
106 
107  protected:
108  // Input: X, label
109  // Output: Y
110  static constexpr T kLOG_THRESHOLD() {
111  return static_cast<T>(1e-20);
112  }
113 };
114 
115 template <typename T, class Context>
116 class CrossEntropyGradientOp final : public Operator<Context> {
117  public:
118  USE_SIMPLE_CTOR_DTOR(CrossEntropyGradientOp);
119  USE_OPERATOR_CONTEXT_FUNCTIONS;
120  bool RunOnDevice() override;
121 
122  protected:
123  // Input: X, label, dY
124  // Ouptut: dX. There is no gradient with respect to the label.
125  static constexpr T kLOG_THRESHOLD() {
126  return static_cast<T>(1e-20);
127  }
128 };
129 
130 } // namespace caffe2
131 
132 #endif // CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...