Caffe2 - C++ API
A deep learning, cross-platform ML framework
sqrt_op.cc
1 #include <Eigen/Core>
2 #include "caffe2/operators/elementwise_op.h"
3 
4 namespace caffe2 {
5 
6 struct SqrtCPUFunctor {
7  template <typename T>
8  inline void
9  operator()(const int n, const T* x, T* y, CPUContext* /*device_context*/) {
10  EigenVectorArrayMap<T>(y, n) = ConstEigenVectorArrayMap<T>(x, n).sqrt();
11  }
12 };
13 
14 REGISTER_CPU_OPERATOR(
15  Sqrt,
17 // Input: X, output: Y
18 OPERATOR_SCHEMA(Sqrt)
19  .NumInputs(1)
20  .NumOutputs(1)
21  .AllowInplace({{0, 0}})
22  .IdenticalTypeAndShape()
23  .SetDoc(R"DOC(
24 Computes the element-wise sqrt of the input.
25 )DOC")
26  .Input(0, "X", "ND input tensor")
27  .Output(0, "Y", "ND input tensor");
28 
30  using GradientMakerBase::GradientMakerBase;
31  vector<OperatorDef> GetGradientDefs() override {
32  Argument scale_arg;
33  scale_arg.set_name("scale");
34  scale_arg.set_f(0.5);
35  return vector<OperatorDef>{CreateOperatorDef(
36  "Scale",
37  "",
38  std::vector<string>{GO(0)},
39  std::vector<string>{GI(0)},
40  std::vector<Argument>{scale_arg}),
41  CreateOperatorDef(
42  "Div",
43  "",
44  std::vector<string>{GI(0), O(0)},
45  std::vector<string>{GI(0)})};
46  }
47 };
48 REGISTER_GRADIENT(Sqrt, GetSqrtGradient);
49 } // namespace caffe2
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:66
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...