Caffe2 - C++ API
A deep learning, cross platform ML framework
Related Pages
Modules
Data Structures
Files
C++ API
Python API
GitHub
File List
Globals
caffe2
operators
sqrt_op.cc
1
#include <Eigen/Core>
2
#include "caffe2/operators/elementwise_op.h"
3
4
namespace
caffe2
{
5
6
struct
SqrtCPUFunctor
{
7
template
<
typename
T>
8
inline
void
9
operator()(
const
int
n,
const
T* x, T* y,
CPUContext
*
/*device_context*/
) {
10
EigenVectorArrayMap<T>(y, n) = ConstEigenVectorArrayMap<T>(x, n).sqrt();
11
}
12
};
13
14
REGISTER_CPU_OPERATOR(
15
Sqrt,
16
UnaryElementwiseOp
<
TensorTypes<float>
,
CPUContext
,
SqrtCPUFunctor
>);
17
// Input: X, output: Y
18
OPERATOR_SCHEMA(Sqrt)
19
.NumInputs(1)
20
.NumOutputs(1)
21
.AllowInplace({{0, 0}})
22
.IdenticalTypeAndShape()
23
.SetDoc(R
"DOC(
24
Computes the element-wise sqrt of the input.
25
)DOC")
26
.Input(0,
"X"
,
"ND input tensor"
)
27
.Output(0,
"Y"
,
"ND input tensor"
);
28
29
class
GetSqrtGradient
:
public
GradientMakerBase
{
30
using
GradientMakerBase::GradientMakerBase;
31
vector<OperatorDef> GetGradientDefs()
override
{
32
Argument scale_arg;
33
scale_arg.set_name(
"scale"
);
34
scale_arg.set_f(0.5);
35
return
vector<OperatorDef>{CreateOperatorDef(
36
"Scale"
,
37
""
,
38
std::vector<string>{GO(0)},
39
std::vector<string>{GI(0)},
40
std::vector<Argument>{scale_arg}),
41
CreateOperatorDef(
42
"Div"
,
43
""
,
44
std::vector<string>{GI(0), O(0)},
45
std::vector<string>{GI(0)})};
46
}
47
};
48
REGISTER_GRADIENT(Sqrt,
GetSqrtGradient
);
49
}
// namespace caffe2
caffe2::GradientMakerBase
Definition:
operator_gradient.h:47
caffe2::CPUContext
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition:
context.h:66
caffe2::SqrtCPUFunctor
Definition:
sqrt_op.cc:6
caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition:
convert_encoded_to_raw_leveldb.cc:47
caffe2::TensorTypes
Definition:
operator.h:547
caffe2::UnaryElementwiseWithArgsOp
Definition:
elementwise_op.h:36
caffe2::GetSqrtGradient
Definition:
sqrt_op.cc:29
Generated on Thu Apr 19 2018 13:03:56 for Caffe2 - C++ API by
Doxygen 1.8.11