Caffe2 - C++ API
A deep learning, cross-platform ML framework
swish_op.cc
1 #include "swish_op.h"
2 #include "caffe2/core/types.h"
3 #include "caffe2/operators/elementwise_op.h"
4 #include "caffe2/utils/math.h"
5 
6 namespace caffe2 {
8  template <typename T>
9  inline void
10  operator()(const int n, const T* x, T* y, CPUContext* /*device_context*/) {
11  ConstEigenVectorArrayMap<T> xM(x, n);
12  EigenVectorArrayMap<T>(y, n) = xM / (1. + (-xM).exp());
13  }
14 };
15 
16 template <>
17 template <typename T>
19  auto& Xin = Input(X);
20  auto& Yin = Input(Y);
21  auto& DYin = Input(DY);
22  auto* DXout = Output(DX);
23  CAFFE_ENFORCE_EQ(Xin.size(), Yin.size());
24  CAFFE_ENFORCE_EQ(DYin.size(), Yin.size());
25  DXout->ResizeLike(Yin);
26 
27  const float* Xdata = Xin.template data<float>();
28  const float* Ydata = Yin.template data<float>();
29  const float* dYdata = DYin.template data<float>();
30  float* dXdata = DXout->template mutable_data<float>();
31 
32  EigenVectorArrayMap<float> dXvec(dXdata, DXout->size());
33  ConstEigenVectorArrayMap<float> Xvec(Xdata, Xin.size());
34  ConstEigenVectorArrayMap<float> Yvec(Ydata, Yin.size());
35  ConstEigenVectorArrayMap<float> dYvec(dYdata, DYin.size());
36 
37  // dx = dy * (y + sigmoid(x)*(1-y))
38  dXvec = dYvec * (Yvec + (1. / (1. + (-Xvec).exp())) * (1. - Yvec));
39  return true;
40 }
41 
42 REGISTER_CPU_OPERATOR(
43  Swish,
46  CPUContext,
48 REGISTER_CPU_OPERATOR(SwishGradient, SwishGradientOp<CPUContext>);
49 
50 // Input: X, output: Y
51 OPERATOR_SCHEMA(Swish)
52  .NumInputs(1)
53  .NumOutputs(1)
54  .IdenticalTypeAndShape()
55  .SetDoc(R"DOC(
56 Swish takes one input data (Tensor<T>) and produces one output data
57 (Tensor<T>) where the swish function, y = x / (1 + exp(-x)), is applied to the
58 tensor elementwise.
59 )DOC")
60  .Input(0, "X", "1D input tensor")
61  .Output(0, "Y", "1D output tensor");
62 // Input: X, Y, dY, output: dX
63 OPERATOR_SCHEMA(SwishGradient)
64  .NumInputs(3)
65  .NumOutputs(1)
66  .AllowInplace({{2, 0}})
67  .SetDoc(R"DOC(
68 SwishGradient takes X, Y and dY and uses this to update dX according to the
69 chain rule and derivatives of the swish function.
70 )DOC");
71 
72 REGISTER_GRADIENT(Swish, GetSwishGradient);
73 } // namespace caffe2
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:66
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...