Caffe2 - C++ API
A deep learning, cross-platform ML framework
sigmoid_op.cc
1 #include "caffe2/operators/elementwise_op.h"
2 #include "caffe2/utils/math.h"
3 
4 namespace caffe2 {
5 
7  template <typename T>
8  inline void
9  operator()(const int n, const T* x, T* y, CPUContext* /*device_context*/) {
10  ConstEigenVectorArrayMap<T> xM(x, n);
11  EigenVectorArrayMap<T>(y, n) = 1. / (1. + (-xM).exp());
12  }
13 };
14 
16  template <typename T>
17  inline void Run(
18  const int n,
19  const T* y,
20  const T* dy,
21  T* dx,
22  CPUContext* /*device_context*/) {
23  ConstEigenVectorArrayMap<T> yM(y, n), dyM(dy, n);
24  EigenVectorArrayMap<T>(dx, n) = dyM * yM * (1. - yM);
25  }
26 };
27 
// NOTE(review): this rendering is missing source lines 30, 33-34 and 36, so
// both CPU registrations below (Sigmoid, SigmoidGradient) are incomplete here;
// the template arguments that bind them to the CPU functors above are not
// visible. Consult the original sigmoid_op.cc before relying on this text.
28 REGISTER_CPU_OPERATOR(
29  Sigmoid, UnaryElementwiseOp<
31 REGISTER_CPU_OPERATOR(
32  SigmoidGradient,
35  CPUContext,
37 
38 // Input: X, output: Y
39 OPERATOR_SCHEMA(Sigmoid)
40  .NumInputs(1)
41  .NumOutputs(1)
42  .AllowInplace({{0, 0}})
43  .IdenticalTypeAndShape()
44  .SetDoc(R"DOC(
45 Sigmoid takes one input data (Tensor<T>) and produces one output data
46 (Tensor<T>) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the
47 tensor elementwise.
48 )DOC")
49  .Input(0, "X", "1D input tensor")
50  .Output(0, "Y", "1D output tensor")
51  .InheritOnnxSchema("Sigmoid");
52 // Input: Y, dY, output: dX
53 OPERATOR_SCHEMA(SigmoidGradient)
54  .NumInputs(2)
55  .NumOutputs(1)
56  .AllowInplace({{1, 0}})
57  .SetDoc(R"DOC(
58 SigmoidGradient takes both Y and dY and uses this to update dX according to the
59 chain rule and derivatives of the sigmoid function.
60 )DOC");
61 
63  using GradientMakerBase::GradientMakerBase;
64  vector<OperatorDef> GetGradientDefs() override {
65  return SingleGradientDef(
66  "SigmoidGradient", "",
67  vector<string>{O(0), GO(0)},
68  vector<string>{GI(0)});
69  }
70 };
// Registers GetSigmoidGradient as the gradient maker for Sigmoid; it emits a
// single "SigmoidGradient" op consuming the forward output and its gradient.
71 REGISTER_GRADIENT(Sigmoid, GetSigmoidGradient);
72 } // namespace caffe2
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:66
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Performs a binary operation (e.g. ...